diff --git a/.github/workflows/health_check.yml b/.github/workflows/health_check.yml index 611e6c0a17..91fe81ed27 100644 --- a/.github/workflows/health_check.yml +++ b/.github/workflows/health_check.yml @@ -19,7 +19,7 @@ jobs: - name: Installing dependencies run: | pip install tensorflow pytest pytest-cov - pip install -e ./ + pip install -e .[default] - name: Code instrumentation run: | pytest -v --cov --cov-report xml:coverage.xml diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml index 6191c7bad1..4bc03606e7 100644 --- a/.github/workflows/pr_checks.yml +++ b/.github/workflows/pr_checks.yml @@ -31,7 +31,7 @@ jobs: - name: Installing dependencies run: | pip install tensorflow pytest - pip install -e ./ + pip install -e .[default] - name: Unit testing run: | pytest -v diff --git a/CHANGELOG.md b/CHANGELOG.md index ca487c77c0..8280b84be4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,20 +8,55 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## \[Unreleased\] ### Added -- TBD +- A new installation target: `pip install datumaro[default]`, which should + be used by default. The simple `datumaro` is supposed for library users. + () +- Dataset and project versioning capabilities (Git-like) + () +- "dataset revpath" concept in CLI, allowing to pass a dataset path with + the dataset format in `diff`, `merge`, `explain` and `info` CLI commands + () +- `add`, `remove`, `commit`, `checkout`, `log`, `status`, `info` CLI commands + () ### Changed +- A project can contain and manage multiple datasets instead of a single one. + CLI operations can be applied to the whole project, or to separate datasets. + Datasets are modified inplace, by default + () +- CLI help for builtin plugins doesn't require project + () - Annotation-related classes were moved into a new module, `datumaro.components.annotation` () - Rollback utilities replaced with Scope utilities () +- The `Project` class from `datumaro.components` is changed completely + () +- `diff` and `ediff` are joined into a single `diff` CLI command + () +- Projects use new file layout, incompatible with old projects. + An old project can be updated with `datum project migrate` + () +- Inheriting `CliPlugin` is not required in plugin classes + () +- `Importer`s do not create `Project`s anymore and just return a list of + extractor configurations + () ### Deprecated - TBD ### Removed -- TBD +- `import`, `project merge` CLI commands + () +- Support for project hierarchies. A project cannot be a source anymore + () +- Project cannot have independent internal dataset anymore. 
All the project + data must be stored in the project data sources + () +- `datumaro_project` format + () ### Fixed - Deprecation warning in `open_images_format.py` @@ -71,7 +106,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Calling `ProjectDataset.transform()` with a string argument () - Attributes casting for CVAT format () - Loading of custom project plugins () -- Reading, writing anno file and saving name of the subset for test subset +- Reading, writing anno file and saving name of the subset for test subset () ### Security diff --git a/MANIFEST.in b/MANIFEST.in index 57775223d2..f717cef33b 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,3 @@ include README.md -include requirements-core.txt \ No newline at end of file +include requirements-core.txt +include requirements-default.txt diff --git a/datumaro/cli/__main__.py b/datumaro/cli/__main__.py index 189f956317..4d8e0b640d 100644 --- a/datumaro/cli/__main__.py +++ b/datumaro/cli/__main__.py @@ -9,7 +9,8 @@ from ..version import VERSION from . import commands, contexts -from .util import CliException, add_subparser +from .util import add_subparser +from .util.errors import CliException _log_levels = { 'debug': log.DEBUG, @@ -59,22 +60,31 @@ def make_parser(): _LogManager._define_loglevel_option(parser) known_contexts = [ - ('project', contexts.project, "Actions with project (deprecated)"), + ('project', contexts.project, "Actions with project"), ('source', contexts.source, "Actions with data sources"), ('model', contexts.model, "Actions with models"), ] known_commands = [ - ('create', commands.create, "Create project"), - ('import', commands.import_, "Create project from existing dataset"), + ("Project modification:", None, ''), + ('create', commands.create, "Create empty project"), ('add', commands.add, "Add data source to project"), ('remove', commands.remove, "Remove data source from project"), + + ("", None, ''), + ("Project versioning:", None, ''), + ('checkout', commands.checkout, "Switch to another branch or revision"), + ('commit', commands.commit, "Commit changes in tracked files"), + ('log', commands.log, "List history"), + ('status', commands.status, "Display current status"), + + ("", None, ''), + ("Dataset and project operations:", None, ''), ('export', commands.export, "Export project in some format"), - ('filter', commands.filter, "Filter project"), - ('transform', commands.transform, "Transform project"), - ('merge', commands.merge, "Merge projects"), - ('convert', commands.convert, "Convert dataset into another format"), - ('diff', commands.diff, "Compare projects with intersection"), - ('ediff', commands.ediff, "Compare projects for equality"), + ('filter', commands.filter, "Filter project items"), + ('transform', commands.transform, "Modify project items"), + ('merge', commands.merge, "Merge datasets"), + ('convert', commands.convert, "Convert dataset between formats"), + ('diff', commands.diff, "Compare datasets"), ('stats', commands.stats, "Compute project statistics"), ('info', commands.info, "Print project info"), ('explain', commands.explain, "Run Explainable AI algorithm for model"), @@ -105,7 +115,8 @@ def make_parser(): subcommands = parser.add_subparsers(title=subcommands_desc, description="", help=argparse.SUPPRESS) for command_name, command, _ in known_contexts + known_commands: - add_subparser(subcommands, command_name, command.build_parser) + if command is not None: + add_subparser(subcommands, command_name, command.build_parser) return parser @@ -121,7 
+132,10 @@ def main(args=None): return 1 try: - return args.command(args) + retcode = args.command(args) + if retcode is None: + retcode = 0 + return retcode except CliException as e: log.error(e) return 1 diff --git a/datumaro/cli/commands/__init__.py b/datumaro/cli/commands/__init__.py index febb60775a..b96b839e8c 100644 --- a/datumaro/cli/commands/__init__.py +++ b/datumaro/cli/commands/__init__.py @@ -5,6 +5,6 @@ # pylint: disable=redefined-builtin from . import ( - add, convert, create, diff, ediff, explain, export, filter, import_, info, - merge, remove, stats, transform, validate, + add, checkout, commit, convert, create, diff, explain, export, filter, info, + log, merge, remove, stats, status, transform, validate, ) diff --git a/datumaro/cli/commands/checkout.py b/datumaro/cli/commands/checkout.py new file mode 100644 index 0000000000..ce84d8b6d4 --- /dev/null +++ b/datumaro/cli/commands/checkout.py @@ -0,0 +1,70 @@ +# Copyright (C) 2021 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse + +from datumaro.util.scope import scope_add, scoped + +from ..util import MultilineFormatter +from ..util.project import load_project + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Navigate to a revision", + description=""" + Command forms:|n + 1) %(prog)s |n + 2) %(prog)s [--] ...|n + 3) %(prog)s [--] ...|n + |n + 1 - Restores a revision and all the sources in the working directory.|n + 2, 3 - Restores only specified sources from the specified revision.|n + |s|sThe current revision is used, when not set.|n + |s|s"--" is optionally used to separate source names and revisions.|n + |n + Examples:|n + - Restore the previous revision:|n + |s|s%(prog)s HEAD~1 |n |n + - Restore the saved version of a source in the working tree|n + |s|s%(prog)s -- source-1 |n |n + - Restore a previous version of a source|n + |s|s%(prog)s 33fbfbe my-source + """, formatter_class=MultilineFormatter) + + parser.add_argument('_positionals', nargs=argparse.REMAINDER, + help=argparse.SUPPRESS) # workaround for -- eaten by positionals + parser.add_argument('rev', nargs='?', + help="Commit or tag (default: current)") + parser.add_argument('sources', nargs='*', + help="Sources to restore (default: all)") + parser.add_argument('-f', '--force', action='store_true', + help="Allows to overwrite unsaved changes in case of conflicts " + "(default: %(default)s)") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=checkout_command) + + return parser + +@scoped +def checkout_command(args): + has_sep = '--' in args._positionals + if has_sep: + pos = args._positionals.index('--') + if 1 < pos: + raise argparse.ArgumentError(None, + message="Expected no more than 1 revision argument") + else: + pos = 1 + args.rev = (args._positionals[:pos] or [None])[0] + args.sources = args._positionals[pos + has_sep:] + if has_sep and not args.sources: + raise argparse.ArgumentError('sources', message="When '--' is used, " + "at least 1 source name must be specified") + + project = scope_add(load_project(args.project_dir)) + + project.checkout(rev=args.rev, sources=args.sources, force=args.force) + + return 0 diff --git a/datumaro/cli/commands/commit.py b/datumaro/cli/commands/commit.py new file mode 100644 index 0000000000..477a1d283a --- /dev/null +++ b/datumaro/cli/commands/commit.py @@ -0,0 +1,54 @@ +# Copyright (C) 2020-2021 Intel Corporation +# +# 
SPDX-License-Identifier: MIT + +import argparse + +from datumaro.util.scope import scope_add, scoped + +from ..util import MultilineFormatter +from ..util.project import load_project + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Create a revision", + description=""" + Creates a new revision from the current state of the working directory.|n + |n + Examples:|n + - Create a revision:|n + |s|s%(prog)s -m "Added data" + """, formatter_class=MultilineFormatter) + + parser.add_argument('-m', '--message', required=True, help="Commit message") + parser.add_argument('--allow-empty', action='store_true', + help="Allow commits with no changes (default: %(default)s)") + parser.add_argument('--allow-foreign', action='store_true', + help="Allow commits with non-Datumaro changes (default: %(default)s)") + parser.add_argument('--no-cache', action='store_true', + help="Don't put committed datasets into cache, " + "save only metainfo (default: %(default)s)") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=commit_command) + + return parser + +@scoped +def commit_command(args): + project = scope_add(load_project(args.project_dir)) + + old_tree = project.head + + new_commit = project.commit(args.message, allow_empty=args.allow_empty, + allow_foreign=args.allow_foreign, no_cache=args.no_cache) + + new_tree = project.working_tree + diff = project.diff(old_tree, new_tree) + + print("Moved to commit '%s' %s" % (new_commit, args.message)) + print(" %s targets changed" % len(diff)) + for t, s in diff.items(): + print(" %s %s" % (s.name, t)) + + return 0 diff --git a/datumaro/cli/commands/convert.py b/datumaro/cli/commands/convert.py index f21f04d7d1..2c39d27658 100644 --- a/datumaro/cli/commands/convert.py +++ b/datumaro/cli/commands/convert.py @@ -12,34 +12,41 @@ from datumaro.util.os_util import make_file_name from ..contexts.project import FilterModes -from ..util import CliException, MultilineFormatter +from ..util import MultilineFormatter +from ..util.errors import CliException from ..util.project import generate_next_file_name def build_parser(parser_ctor=argparse.ArgumentParser): - builtin_importers = sorted(Environment().importers.items) - builtin_converters = sorted(Environment().converters.items) + builtin_readers = sorted( + set(Environment().importers) | set(Environment().extractors)) + builtin_writers = sorted(Environment().converters) parser = parser_ctor(help="Convert an existing dataset to another format", description=""" - Converts a dataset from one format to another. - You can add your own formats using a project.|n - |n - Supported input formats: %s|n - |n - Supported output formats: %s|n - |n - Examples:|n - - Export a dataset as a PASCAL VOC dataset, include images:|n - |s|sconvert -i src/path -f voc -- --save-images|n - |n - - Export a dataset as a COCO dataset to a specific directory:|n - |s|sconvert -i src/path -f coco -o path/I/like/ - """ % (', '.join(builtin_importers), ', '.join(builtin_converters)), + Converts a dataset from one format to another. + You can add your own formats and do many more by creating a + Datumaro project.|n + |n + This command serves as an alias for the "create", "add", and + "export" commands, allowing to obtain the same results simpler + and faster. 
Check descriptions of these commands for more info.|n + |n + Supported input formats: {}|n + |n + Supported output formats: {}|n + |n + Examples:|n + - Export a dataset as a PASCAL VOC dataset, include images:|n + |s|s%(prog)s -i src/path -f voc -- --save-images|n + |n + - Export a dataset as a COCO dataset to a specific directory:|n + |s|s%(prog)s -i src/path -f coco -o path/I/like/ + """.format(', '.join(builtin_readers), ', '.join(builtin_writers)), formatter_class=MultilineFormatter) parser.add_argument('-i', '--input-path', default='.', dest='source', - help="Path to look for a dataset") + help="Input dataset path (default: current dir)") parser.add_argument('-if', '--input-format', help="Input dataset format. Will try to detect, if not specified.") parser.add_argument('-f', '--output-format', required=True, @@ -49,13 +56,15 @@ def build_parser(parser_ctor=argparse.ArgumentParser): parser.add_argument('--overwrite', action='store_true', help="Overwrite existing files in the save directory") parser.add_argument('-e', '--filter', - help="Filter expression for dataset items") + help="XML XPath filter expression for dataset items. Read \"filter\" " + "command docs for more info") parser.add_argument('--filter-mode', default=FilterModes.i.name, type=FilterModes.parse, - help="Filter mode (options: %s; default: %s)" % \ + help="Filter mode, one of %s (default: %s)" % \ (', '.join(FilterModes.list_options()) , '%(default)s')) parser.add_argument('extra_args', nargs=argparse.REMAINDER, - help="Additional arguments for output format (pass '-- -h' for help)") + help="Additional arguments for output format (pass '-- -h' for help). " + "Must be specified after the main command arguments") parser.set_defaults(command=convert_command) return parser diff --git a/datumaro/cli/commands/create.py b/datumaro/cli/commands/create.py index bc3a6de7fe..c7d56683bc 100644 --- a/datumaro/cli/commands/create.py +++ b/datumaro/cli/commands/create.py @@ -2,8 +2,56 @@ # # SPDX-License-Identifier: MIT -from ..contexts.project import build_create_parser as build_parser +import argparse +import logging as log +import os +import os.path as osp -__all__ = [ - 'build_parser', -] +from datumaro.components.project import Project +from datumaro.util.os_util import rmtree + +from ..util import MultilineFormatter +from ..util.errors import CliException + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Create empty project", + description=""" + Create an empty Datumaro project. 
A project is required for the most of + Datumaro functionality.|n + |n + Examples:|n + - Create a project in the current directory:|n + |s|s%(prog)s|n + |n + - Create a project in other directory:|n + |s|s%(prog)s -o path/I/like/ + """, + formatter_class=MultilineFormatter) + + parser.add_argument('-o', '--output-dir', default='.', dest='dst_dir', + help="Save directory for the new project (default: current dir") + parser.add_argument('--overwrite', action='store_true', + help="Overwrite existing files in the save directory") + parser.set_defaults(command=create_command) + + return parser + +def create_command(args): + project_dir = osp.abspath(args.dst_dir) + + existing_project_dir = Project.find_project_dir(project_dir) + if existing_project_dir and os.listdir(existing_project_dir): + if args.overwrite: + rmtree(existing_project_dir) + else: + raise CliException("Directory '%s' already exists " + "(pass --overwrite to overwrite)" % existing_project_dir) + + log.info("Creating project at '%s'" % project_dir) + + Project.init(project_dir) + + log.info("Project has been created at '%s'" % project_dir) + + return 0 diff --git a/datumaro/cli/commands/diff.py b/datumaro/cli/commands/diff.py index 3c4ce32714..6f3b3443bd 100644 --- a/datumaro/cli/commands/diff.py +++ b/datumaro/cli/commands/diff.py @@ -2,8 +2,216 @@ # # SPDX-License-Identifier: MIT -from ..contexts.project import build_diff_parser as build_parser +from enum import Enum, auto +import argparse +import json +import logging as log +import os +import os.path as osp -__all__ = [ - 'build_parser', -] +from datumaro.components.errors import ProjectNotFoundError +from datumaro.components.operations import DistanceComparator, ExactComparator +from datumaro.util.os_util import rmtree +from datumaro.util.scope import on_error_do, scope_add, scoped + +from ..contexts.project.diff import DiffVisualizer +from ..util import MultilineFormatter +from ..util.errors import CliException +from ..util.project import ( + generate_next_file_name, load_project, parse_full_revpath, +) + + +class ComparisonMethod(Enum): + equality = auto() + distance = auto() + +eq_default_if = ['id', 'group'] # avoid https://bugs.python.org/issue16399 + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Compares two datasets", + description=""" + Compares two datasets. This command has multiple forms:|n + 1) %(prog)s |n + 2) %(prog)s |n + |n + 1 - Compares the current project's main target ('project') + in the working tree with the specified dataset.|n + 2 - Compares two specified datasets.|n + |n + - either a dataset path or a revision path. The full + syntax is:|n + - Dataset paths:|n + |s|s- [ : ]|n + - Revision paths:|n + |s|s- [ @ ] [ : ]|n + |s|s- [ : ]|n + |s|s- |n + |n + Both forms use the -p/--project as a context for plugins. It can be + useful for dataset paths in targets. 
When not specified, the current + project's working tree is used.|n + |n + Annotations can be matched 2 ways:|n + - by equality checking|n + - by distance computation|n + |n + Examples:|n + - Compare two projects by distance, match boxes if IoU > 0.7,|n + |s|s|s|ssave results to Tensorboard:|n + |s|s%(prog)s other/project -o diff/ -f tensorboard --iou-thresh 0.7|n + |n + - Compare two projects for equality, exclude annotation groups |n + |s|s|s|sand the 'is_crowd' attribute from comparison:|n + |s|s%(prog)s other/project/ -if group -ia is_crowd|n + |n + - Compare two datasets, specify formats:|n + |s|s%(prog)s path/to/dataset1:voc path/to/dataset2:coco|n + |n + - Compare the current working tree and a dataset:|n + |s|s%(prog)s path/to/dataset2:coco|n + |n + - Compare a source from a previous revision and a dataset:|n + |s|s%(prog)s HEAD~2:source-2 path/to/dataset2:yolo + """, + formatter_class=MultilineFormatter) + + formats = ', '.join(f.name for f in DiffVisualizer.OutputFormat) + comp_methods = ', '.join(m.name for m in ComparisonMethod) + + def _parse_output_format(s): + try: + return DiffVisualizer.OutputFormat[s.lower()] + except KeyError: + raise argparse.ArgumentError('format', message="Unknown output " + "format '%s', the only available are: %s" % (s, formats)) + + def _parse_comparison_method(s): + try: + return ComparisonMethod[s.lower()] + except KeyError: + raise argparse.ArgumentError('method', message="Unknown comparison " + "method '%s', the only available are: %s" % (s, comp_methods)) + + parser.add_argument('first_target', + help="The first dataset revpath to be compared") + parser.add_argument('second_target', nargs='?', + help="The second dataset revpath to be compared") + parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, + help="Directory to save comparison results " + "(default: generate automatically)") + parser.add_argument('-m', '--method', type=_parse_comparison_method, + default=ComparisonMethod.equality.name, + help="Comparison method, one of {} (default: %(default)s)" \ + .format(comp_methods)) + parser.add_argument('--overwrite', action='store_true', + help="Overwrite existing files in the save directory") + parser.add_argument('-p', '--project', dest='project_dir', + help="Directory of the current project (default: current dir)") + parser.set_defaults(command=diff_command) + + distance_parser = parser.add_argument_group("Distance comparison options") + distance_parser.add_argument('--iou-thresh', default=0.5, type=float, + help="IoU match threshold for shapes (default: %(default)s)") + parser.add_argument('-f', '--format', type=_parse_output_format, + default=DiffVisualizer.DEFAULT_FORMAT.name, + help="Output format, one of {} (default: %(default)s)".format(formats)) + + equality_parser = parser.add_argument_group("Equality comparison options") + equality_parser.add_argument('-iia', '--ignore-item-attr', action='append', + help="Ignore item attribute (repeatable)") + equality_parser.add_argument('-ia', '--ignore-attr', action='append', + help="Ignore annotation attribute (repeatable)") + equality_parser.add_argument('-if', '--ignore-field', action='append', + help="Ignore annotation field (repeatable, default: %s)" % \ + eq_default_if) + equality_parser.add_argument('--match-images', action='store_true', + help='Match dataset items by image pixels instead of ids') + equality_parser.add_argument('--all', action='store_true', + help="Include matches in the output") + + return parser + +@scoped +def diff_command(args): + dst_dir = 
args.dst_dir + if dst_dir: + if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): + raise CliException("Directory '%s' already exists " + "(pass --overwrite to overwrite)" % dst_dir) + else: + dst_dir = generate_next_file_name('diff') + dst_dir = osp.abspath(dst_dir) + + if not osp.exists(dst_dir): + on_error_do(rmtree, dst_dir, ignore_errors=True) + os.makedirs(dst_dir) + + project = None + try: + project = scope_add(load_project(args.project_dir)) + except ProjectNotFoundError: + if args.project_dir: + raise + + try: + if not args.second_target: + first_dataset = project.working_tree.make_dataset() + second_dataset, target_project = \ + parse_full_revpath(args.first_target, project) + if target_project: + scope_add(target_project) + else: + first_dataset, target_project = \ + parse_full_revpath(args.first_target, project) + if target_project: + scope_add(target_project) + + second_dataset, target_project = \ + parse_full_revpath(args.second_target, project) + if target_project: + scope_add(target_project) + except Exception as e: + raise CliException(str(e)) + + if args.method is ComparisonMethod.equality: + if args.ignore_field: + args.ignore_field = eq_default_if + comparator = ExactComparator( + match_images=args.match_images, + ignored_fields=args.ignore_field, + ignored_attrs=args.ignore_attr, + ignored_item_attrs=args.ignore_item_attr) + matches, mismatches, a_extra, b_extra, errors = \ + comparator.compare_datasets(first_dataset, second_dataset) + + output = { + "mismatches": mismatches, + "a_extra_items": sorted(a_extra), + "b_extra_items": sorted(b_extra), + "errors": errors, + } + if args.all: + output["matches"] = matches + + output_file = osp.join(dst_dir, + generate_next_file_name('diff', ext='.json', basedir=dst_dir)) + log.info("Saving diff to '%s'" % output_file) + with open(output_file, 'w') as f: + json.dump(output, f, indent=4, sort_keys=True) + + print("Found:") + print("The first project has %s unmatched items" % len(a_extra)) + print("The second project has %s unmatched items" % len(b_extra)) + print("%s item conflicts" % len(errors)) + print("%s matching annotations" % len(matches)) + print("%s mismatching annotations" % len(mismatches)) + elif args.method is ComparisonMethod.distance: + comparator = DistanceComparator(iou_threshold=args.iou_thresh) + + with DiffVisualizer(save_dir=dst_dir, comparator=comparator, + output_format=args.format) as visualizer: + log.info("Saving diff to '%s'" % dst_dir) + visualizer.save(first_dataset, second_dataset) + + return 0 diff --git a/datumaro/cli/commands/ediff.py b/datumaro/cli/commands/ediff.py deleted file mode 100644 index e835d2e58f..0000000000 --- a/datumaro/cli/commands/ediff.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (C) 2019-2021 Intel Corporation -# -# SPDX-License-Identifier: MIT - -from ..contexts.project import build_ediff_parser as build_parser - -__all__ = [ - 'build_parser', -] diff --git a/datumaro/cli/commands/explain.py b/datumaro/cli/commands/explain.py index fe5c1dd39f..b3fe2fda1c 100644 --- a/datumaro/cli/commands/explain.py +++ b/datumaro/cli/commands/explain.py @@ -7,15 +7,11 @@ import os import os.path as osp -from datumaro.components.project import Project -from datumaro.util.command_targets import ( - ImageTarget, ProjectTarget, SourceTarget, TargetKinds, is_project_path, - target_selector, -) -from datumaro.util.image import load_image, save_image +from datumaro.util.image import is_image, load_image, save_image +from datumaro.util.scope import scope_add, scoped from ..util 
import MultilineFormatter -from ..util.project import load_project +from ..util.project import load_project, parse_full_revpath def build_parser(parser_ctor=argparse.ArgumentParser): @@ -42,16 +38,36 @@ def build_parser(parser_ctor=argparse.ArgumentParser): - RISE for classification|n - RISE for Object Detection|n |n + This command has the following syntax:|n + |s|s%(prog)s |n + |n + - a path to the file.|n + - either a dataset path or a revision path. The full + syntax is:|n + - Dataset paths:|n + |s|s- [ : ]|n + - Revision paths:|n + |s|s- [ @ ] [ : ]|n + |s|s- [ : ]|n + |s|s- |n + Parts can be enclosed in quotes.|n + |n + The current project (-p/--project) is used as a context for plugins + and models. It is used when there is a dataset path in target. + When not specified, the current project's working tree is used.|n + |n Examples:|n - Run RISE on an image, display results:|n - |s|s%(prog)s -t path/to/image.jpg -m mymodel rise --max-samples 50 + |s|s%(prog)s path/to/image.jpg -m mymodel rise --max-samples 50|n + |n + - Run RISE on a source revision:|n + |s|s%(prog)s HEAD~1:source-1 -m model rise """, formatter_class=MultilineFormatter) + parser.add_argument('target', nargs='?', default=None, + help="Inference target - image, revpath (default: project)") parser.add_argument('-m', '--model', required=True, help="Model to use for inference") - parser.add_argument('-t', '--target', default=None, - help="Inference target - image, source, project " - "(default: current project)") parser.add_argument('-o', '--output-dir', dest='save_dir', default=None, help="Directory to save output (default: display only)") @@ -95,28 +111,14 @@ def build_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def explain_command(args): - project_path = args.project_dir - if is_project_path(project_path): - project = Project.load(project_path) - else: - project = None - args.target = target_selector( - ProjectTarget(is_default=True, project=project), - SourceTarget(project=project), - ImageTarget() - )(args.target) - if args.target[0] == TargetKinds.project: - if is_project_path(args.target[1]): - args.project_dir = osp.dirname(osp.abspath(args.target[1])) - - from matplotlib import cm import cv2 - project = load_project(args.project_dir) + project = scope_add(load_project(args.project_dir)) - model = project.make_executable_model(args.model) + model = project.working_tree.models.make_executable_model(args.model) if str(args.algorithm).lower() != 'rise': raise NotImplementedError() @@ -132,8 +134,8 @@ def explain_command(args): det_conf_thresh=args.det_conf_thresh, batch_size=args.batch_size) - if args.target[0] == TargetKinds.image: - image_path = args.target[1] + if args.target and is_image(args.target): + image_path = args.target image = load_image(image_path) log.info("Running inference explanation for '%s'" % image_path) @@ -166,23 +168,20 @@ def explain_command(args): disp = (image + cm.jet(heatmap)[:, :, 2::-1]) / 2 cv2.imshow(file_name + '-heatmap-%s' % j, disp) cv2.waitKey(0) - elif args.target[0] == TargetKinds.source or \ - args.target[0] == TargetKinds.project: - if args.target[0] == TargetKinds.source: - source_name = args.target[1] - dataset = project.make_source_project(source_name).make_dataset() - log.info("Running inference explanation for '%s'" % source_name) - else: - project_name = project.config.project_name - dataset = project.make_dataset() - log.info("Running inference explanation for '%s'" % project_name) + + else: + dataset, target_project = \ + 
parse_full_revpath(args.target or 'project', project) + if target_project: + scope_add(target_project) + + log.info("Running inference explanation for '%s'" % args.target) for item in dataset: image = item.image.data if image is None: - log.warning( - "Dataset item %s does not have image data. Skipping." % \ - (item.id)) + log.warning("Item %s does not have image data. Skipping.", + item.id) continue heatmap_iter = rise.apply(image) @@ -204,7 +203,5 @@ def explain_command(args): disp = (image + cm.jet(heatmap)[:, :, 2::-1]) / 2 cv2.imshow(item.id + '-heatmap-%s' % j, disp) cv2.waitKey(0) - else: - raise NotImplementedError() return 0 diff --git a/datumaro/cli/commands/import_.py b/datumaro/cli/commands/import_.py deleted file mode 100644 index 0d6ef829c0..0000000000 --- a/datumaro/cli/commands/import_.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (C) 2019-2021 Intel Corporation -# -# SPDX-License-Identifier: MIT - -from ..contexts.project import build_import_parser as build_parser - -__all__ = [ - 'build_parser', -] diff --git a/datumaro/cli/commands/info.py b/datumaro/cli/commands/info.py index e6c693bbf3..1007f4f4e2 100644 --- a/datumaro/cli/commands/info.py +++ b/datumaro/cli/commands/info.py @@ -2,8 +2,112 @@ # # SPDX-License-Identifier: MIT -from ..contexts.project import build_info_parser as build_parser +import argparse -__all__ = [ - 'build_parser', -] +from datumaro.components.errors import ( + DatasetMergeError, MissingObjectError, ProjectNotFoundError, +) +from datumaro.components.extractor import AnnotationType +from datumaro.util.scope import scope_add, scoped + +from ..util import MultilineFormatter +from ..util.project import load_project, parse_full_revpath + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Prints dataset overview", + description=""" + Prints info about the dataset at , or about the current + project's combined dataset, if none is specified.|n + |n + - either a dataset path or a revision path. The full + syntax is:|n + - Dataset paths:|n + |s|s- [ : ]|n + - Revision paths:|n + |s|s- [ @ ] [ : ]|n + |s|s- [ : ]|n + |s|s- |n + |n + Both forms use the -p/--project as a context for plugins. It can be + useful for dataset paths in targets. 
When not specified, the current + project's working tree is used.|n + |n + Examples:|n + - Print dataset info for the current project's working tree:|n + |s|s%(prog)s|n + |n + - Print dataset info for a path and a format name:|n + |s|s%(prog)s path/to/dataset:voc|n + |n + - Print dataset info for a source from a past revision:|n + |s|s%(prog)s HEAD~2:source-2 + """, + formatter_class=MultilineFormatter) + + parser.add_argument('target', nargs='?', default='project', + metavar='revpath', + help="Target dataset revpath") + parser.add_argument('--all', action='store_true', + help="Print all information") + parser.add_argument('-p', '--project', dest='project_dir', + help="Directory of the current project (default: current dir)") + parser.set_defaults(command=info_command) + + return parser + +@scoped +def info_command(args): + project = None + try: + project = scope_add(load_project(args.project_dir)) + except ProjectNotFoundError: + if args.project_dir: + raise + + try: + # TODO: avoid computing working tree hashes + dataset, target_project = parse_full_revpath(args.target, project) + if target_project: + scope_add(target_project) + except DatasetMergeError as e: + dataset = None + dataset_problem = "Can't merge project sources automatically: %s " \ + "Conflicting sources are: %s" % (e, ', '.join(e.sources)) + except MissingObjectError as e: + dataset = None + dataset_problem = str(e) + + def print_dataset_info(dataset, indent=''): + print("%slength:" % indent, len(dataset)) + + categories = dataset.categories() + print("%scategories:" % indent, ', '.join(c.name for c in categories)) + + for cat_type, cat in categories.items(): + print("%s %s:" % (indent, cat_type.name)) + if cat_type == AnnotationType.label: + print("%s count:" % indent, len(cat.items)) + + count_threshold = 10 + if args.all: + count_threshold = len(cat.items) + labels = ', '.join(c.name for c in cat.items[:count_threshold]) + if count_threshold < len(cat.items): + labels += " (and %s more)" % ( + len(cat.items) - count_threshold) + print("%s labels:" % indent, labels) + + if dataset is not None: + print_dataset_info(dataset) + + subsets = dataset.subsets() + print("subsets:", ', '.join(subsets)) + for subset_name in subsets: + subset = dataset.get_subset(subset_name) + print(" '%s':" % subset_name) + print_dataset_info(subset, indent=" ") + else: + print("Dataset info is not available: ", dataset_problem) + + return 0 diff --git a/datumaro/cli/commands/log.py b/datumaro/cli/commands/log.py new file mode 100644 index 0000000000..75452a325b --- /dev/null +++ b/datumaro/cli/commands/log.py @@ -0,0 +1,34 @@ +# Copyright (C) 2021 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse + +from datumaro.util.scope import scope_add, scoped + +from ..util.project import load_project + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(description="Prints project history.") + + parser.add_argument('-n', '--max-count', default=10, type=int, + help="Count of last commits to print (default: %(default)s)") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=log_command) + + return parser + +@scoped +def log_command(args): + project = scope_add(load_project(args.project_dir)) + + revisions = project.history(args.max_count) + if revisions: + for rev, message in revisions: + print('%s %s' % (rev, message)) + else: + print("(Project history is empty)") + + return 0 diff 
--git a/datumaro/cli/commands/merge.py b/datumaro/cli/commands/merge.py index 153675faaa..565bbfbf9d 100644 --- a/datumaro/cli/commands/merge.py +++ b/datumaro/cli/commands/merge.py @@ -9,37 +9,77 @@ import os import os.path as osp -from datumaro.components.errors import DatasetMergeError, DatasetQualityError +from datumaro.components.dataset import DEFAULT_FORMAT +from datumaro.components.errors import ( + DatasetMergeError, DatasetQualityError, ProjectNotFoundError, +) from datumaro.components.operations import IntersectMerge -from datumaro.components.project import Project +from datumaro.util.scope import scope_add, scoped -from ..util import CliException, MultilineFormatter, at_least -from ..util.project import generate_next_file_name, load_project +from ..util import MultilineFormatter +from ..util.errors import CliException +from ..util.project import ( + generate_next_file_name, load_project, parse_full_revpath, +) def build_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor(help="Merge few projects", description=""" - Merges multiple datasets into one. This can be useful if you - have few annotations and wish to merge them, - taking into consideration potential overlaps and conflicts. - This command can try to find a common ground by voting or - return a list of conflicts.|n - |n - Examples:|n - - Merge annotations from 3 (or more) annotators:|n - |s|smerge project1/ project2/ project3/|n - - Check groups of the merged dataset for consistence:|n - |s|s|slook for groups consising of 'person', 'hand' 'head', 'foot'|n - |s|smerge project1/ project2/ -g 'person,hand?,head,foot?' + Merges multiple datasets into one and produces a new + dataset in the default format. This can be useful if you + have few annotations and wish to merge them, + taking into consideration potential overlaps and conflicts. + This command can try to find a common ground by voting or + return a list of conflicts.|n + |n + This command has multiple forms:|n + 1) %(prog)s |n + 2) %(prog)s ...|n + |n + 1 - Merges the current project's main target ('project') + in the working tree with the specified dataset.|n + 2 - Merges the specified datasets. + Note that the current project is not included in the list of merged + sources automatically.|n + |n + - either a dataset path or a revision path. The full + syntax is:|n + - Dataset paths:|n + |s|s- [ : ]|n + - Revision paths:|n + |s|s- [ @ ] [ : ]|n + |s|s- [ : ]|n + |s|s- |n + |n + The current project (-p/--project) is used as a context for plugins. + It can be useful for dataset paths in targets. 
When not specified, + the current project's working tree is used.|n + |n + Examples:|n + - Merge annotations from 3 (or more) annotators:|n + |s|s%(prog)s project1/ project2/ project3/|n + |n + - Check groups of the merged dataset for consistency:|n + |s|s|slook for groups consising of 'person', 'hand' 'head', 'foot'|n + |s|s%(prog)s project1/ project2/ -g 'person,hand?,head,foot?'|n + |n + - Merge two datasets, specify formats:|n + |s|s%(prog)s path/to/dataset1:voc path/to/dataset2:coco|n + |n + - Merge the current working tree and a dataset:|n + |s|s%(prog)s path/to/dataset2:coco|n + |n + - Merge a source from a previous revision and a dataset:|n + |s|s%(prog)s HEAD~2:source-2 path/to/dataset2:yolo """, formatter_class=MultilineFormatter) def _group(s): return s.split(',') - parser.add_argument('project', nargs='+', action=at_least(2), - help="Path to a project (repeatable)") + parser.add_argument('targets', nargs='+', + help="Target dataset revpaths (repeatable)") parser.add_argument('-iou', '--iou-thresh', default=0.25, type=float, help="IoU match threshold for segments (default: %(default)s)") parser.add_argument('-oconf', '--output-conf-thresh', @@ -50,21 +90,21 @@ def _group(s): help="Minimum count for a label and attribute voting " "results to be counted (default: %(default)s)") parser.add_argument('-g', '--groups', action='append', type=_group, - default=[], help="A comma-separated list of labels in " "annotation groups to check. '?' postfix can be added to a label to" "make it optional in the group (repeatable)") parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, - help="Output directory (default: current project's dir)") + help="Output directory (default: generate a new one)") parser.add_argument('--overwrite', action='store_true', help="Overwrite existing files in the save directory") + parser.add_argument('-p', '--project', dest='project_dir', + help="Directory of the 'current' project (default: current dir)") parser.set_defaults(command=merge_command) return parser +@scoped def merge_command(args): - source_projects = [load_project(p) for p in args.project] - dst_dir = args.dst_dir if dst_dir: if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): @@ -72,28 +112,38 @@ def merge_command(args): "(pass --overwrite to overwrite)" % dst_dir) else: dst_dir = generate_next_file_name('merged') + dst_dir = osp.abspath(dst_dir) + + project = None + try: + project = scope_add(load_project(args.project_dir)) + except ProjectNotFoundError: + if args.project_dir: + raise source_datasets = [] - for p in source_projects: - log.debug("Loading project '%s' dataset", p.config.project_name) - source_datasets.append(p.make_dataset()) + try: + if len(args.targets) == 1: + source_datasets.append(project.working_tree.make_dataset()) + + for t in args.targets: + target_dataset, target_project = parse_full_revpath(t, project) + if target_project: + scope_add(target_project) + source_datasets.append(target_dataset) + except Exception as e: + raise CliException(str(e)) merger = IntersectMerge(conf=IntersectMerge.Conf( - pairwise_dist=args.iou_thresh, groups=args.groups, + pairwise_dist=args.iou_thresh, groups=args.groups or [], output_conf_thresh=args.output_conf_thresh, quorum=args.quorum )) merged_dataset = merger(source_datasets) - - merged_project = Project() - output_dataset = merged_project.make_dataset() - output_dataset.define_categories(merged_dataset.categories()) - merged_dataset = output_dataset.update(merged_dataset) - 
merged_dataset.save(save_dir=dst_dir) + merged_dataset.export(save_dir=dst_dir, format=DEFAULT_FORMAT) report_path = osp.join(dst_dir, 'merge_report.json') save_merge_report(merger, report_path) - dst_dir = osp.abspath(dst_dir) log.info("Merge results have been saved to '%s'" % dst_dir) log.info("Report has been saved to '%s'" % report_path) diff --git a/datumaro/cli/commands/status.py b/datumaro/cli/commands/status.py new file mode 100644 index 0000000000..5900b7a29a --- /dev/null +++ b/datumaro/cli/commands/status.py @@ -0,0 +1,45 @@ +# Copyright (C) 2021 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import argparse + +from datumaro.cli.util import MultilineFormatter +from datumaro.util.scope import scope_add, scoped + +from ..util.project import load_project + + +def build_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Prints project status.", + description=""" + This command prints the summary of the project changes between + the working tree of a project and its HEAD revision. + """, + formatter_class=MultilineFormatter + ) + + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.set_defaults(command=status_command) + + return parser + +@scoped +def status_command(args): + project = scope_add(load_project(args.project_dir)) + + statuses = project.status() + + if project.branch: + print("On branch '%s', commit %s" % (project.branch, project.head_rev)) + else: + print("HEAD is detached at commit %s" % project.head_rev) + + if statuses: + for target, status in statuses.items(): + print('%s\t%s' % (status.name, target)) + else: + print("Working directory clean") + + return 0 diff --git a/datumaro/cli/contexts/model.py b/datumaro/cli/contexts/model.py index 178b389deb..799d99a0e1 100644 --- a/datumaro/cli/contexts/model.py +++ b/datumaro/cli/contexts/model.py @@ -6,61 +6,79 @@ import logging as log import os import os.path as osp -import shutil +from datumaro.components.errors import ProjectNotFoundError from datumaro.components.project import Environment -from datumaro.util.scope import on_error_do, scoped +from datumaro.util.os_util import rmtree +from datumaro.util.scope import on_error_do, scope_add, scoped -from ..util import CliException, MultilineFormatter, add_subparser +from ..util import MultilineFormatter, add_subparser +from ..util.errors import CliException from ..util.project import ( generate_next_file_name, generate_next_name, load_project, + parse_full_revpath, ) def build_add_parser(parser_ctor=argparse.ArgumentParser): - builtins = sorted(Environment().launchers.items) + builtins = sorted(Environment().launchers) parser = parser_ctor(help="Add model to project", description=""" - Registers an executable model into a project. A model requires - a launcher to be executed. Each launcher has its own options, which - are passed after '--' separator, pass '-- -h' for more info. - |n - List of builtin launchers: %s - """ % ', '.join(builtins), + Adds an executable model into a project. A model requires + a launcher to be executed. Each launcher has its own options, which + are passed after the '--' separator, pass '-- -h' for more info. 
+ |n + List of builtin launchers: {}|n + |n + Examples:|n + - Add an OpenVINO model into a project:|n + |s|s%(prog)s -l openvino -- -d model.xml -w model.bin -i parse_outs.py + """.format(', '.join(builtins)), formatter_class=MultilineFormatter) + parser.add_argument('-n', '--name', default=None, + help="Name of the model to be added (default: generate automatically)") parser.add_argument('-l', '--launcher', required=True, help="Model launcher") - parser.add_argument('extra_args', nargs=argparse.REMAINDER, default=None, - help="Additional arguments for converter (pass '-- -h' for help)") parser.add_argument('--copy', action='store_true', - help="Copy the model to the project") - parser.add_argument('-n', '--name', default=None, - help="Name of the model to be added (default: generate automatically)") - parser.add_argument('--overwrite', action='store_true', - help="Overwrite if exists") + help="Copy model data into project (default: %(default)s)") + parser.add_argument('--no-check', action='store_true', + help="Don't check model loading (default: %(default)s)") parser.add_argument('-p', '--project', dest='project_dir', default='.', help="Directory of the project to operate on (default: current dir)") + parser.add_argument('extra_args', nargs=argparse.REMAINDER, default=None, + help="Additional arguments for converter (pass '-- -h' for help)") parser.set_defaults(command=add_command) return parser @scoped def add_command(args): - project = load_project(args.project_dir) + show_plugin_help = '-h' in args.extra_args or '--help' in args.extra_args - if args.name: - if not args.overwrite and args.name in project.config.models: - raise CliException("Model '%s' already exists " - "(pass --overwrite to overwrite)" % args.name) + project = None + try: + project = scope_add(load_project(args.project_dir)) + except ProjectNotFoundError: + if not show_plugin_help and args.project_dir: + raise + + if project is not None: + env = project.env else: - args.name = generate_next_name( - project.config.models, 'model', '-', default=0) - assert args.name not in project.config.models, args.name + env = Environment() + + name = args.name + if name: + if name in project.models: + raise CliException("Model '%s' already exists" % name) + else: + name = generate_next_name(list(project.models), + 'model', sep='-', default=0) try: - launcher = project.env.launchers[args.launcher] + launcher = env.launchers[args.launcher] except KeyError: raise CliException("Launcher '%s' is not found" % args.launcher) @@ -70,33 +88,34 @@ def add_command(args): if args.copy: log.info("Copying model data") - model_dir = osp.join(project.config.project_dir, - project.local_model_dir(args.name)) + model_dir = project.model_data_dir(name) os.makedirs(model_dir, exist_ok=False) - on_error_do(shutil.rmtree, model_dir, ignore_errors=True) + on_error_do(rmtree, model_dir, ignore_errors=True) try: cli_plugin.copy_model(model_dir, model_args) except (AttributeError, NotImplementedError): - log.error("Can't copy: copying is not available for '%s' models" % \ + raise NotImplementedError( + "Can't copy: copying is not available for '%s' models. 
" % args.launcher) - log.info("Checking the model") - project.add_model(args.name, { - 'launcher': args.launcher, - 'options': model_args, - }) - project.make_executable_model(args.name) + project.add_model(name, launcher=args.launcher, options=model_args) + on_error_do(project.remove_model, name, ignore_errors=True) + + if not args.no_check: + log.info("Checking the model...") + project.make_model(name) project.save() - log.info("Model '%s' with launcher '%s' has been added to project '%s'" % \ - (args.name, args.launcher, project.config.project_name)) + log.info("Model '%s' with launcher '%s' has been added to project", + name, args.launcher) return 0 def build_remove_parser(parser_ctor=argparse.ArgumentParser): - parser = parser_ctor() + parser = parser_ctor(help="Remove model from project", + description="Remove a model from a project") parser.add_argument('name', help="Name of the model to be removed") @@ -106,8 +125,9 @@ def build_remove_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def remove_command(args): - project = load_project(args.project_dir) + project = scope_add(load_project(args.project_dir)) project.remove_model(args.name) project.save() @@ -115,35 +135,57 @@ def remove_command(args): return 0 def build_run_parser(parser_ctor=argparse.ArgumentParser): - parser = parser_ctor() + parser = parser_ctor(help="Launches model inference", + description=""" + Launches model inference on a dataset.|n + |n + Target dataset is specified by a revpath. The full syntax is:|n + - Dataset paths:|n + |s|s- [ : ]|n + - Revision paths:|n + |s|s- [ @ ] [ : ]|n + |s|s- [ : ]|n + |s|s- |n + |n + Both forms use the -p/--project as a context for plugins and models. + When not specified, the current project's working tree is used.|n + """, + formatter_class=MultilineFormatter) + parser.add_argument('target', nargs='?', default='project', + help="Target dataset revpath (default: %(default)s)") parser.add_argument('-o', '--output-dir', dest='dst_dir', - help="Directory to save output") + help="Directory to save output (default: auto-generated)") parser.add_argument('-m', '--model', dest='model_name', required=True, help="Model to apply to the project") parser.add_argument('-p', '--project', dest='project_dir', default='.', help="Directory of the project to operate on (default: current dir)") parser.add_argument('--overwrite', action='store_true', - help="Overwrite if exists") + help="Overwrite output directory if exists") parser.set_defaults(command=run_command) return parser +@scoped def run_command(args): - project = load_project(args.project_dir) - dst_dir = args.dst_dir if dst_dir: if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): raise CliException("Directory '%s' already exists " "(pass --overwrite overwrite)" % dst_dir) else: - dst_dir = generate_next_file_name('%s-inference' % \ - project.config.project_name) + dst_dir = generate_next_file_name('%s-inference' % args.model_name) + dst_dir = osp.abspath(dst_dir) - project.make_dataset().apply_model( - save_dir=osp.abspath(dst_dir), - model=args.model_name) + project = scope_add(load_project(args.project_dir)) + + dataset, target_project = parse_full_revpath(args.target, project) + if target_project: + scope_add(target_project) + + model = project.make_model(args.model_name) + inference = dataset.run_model(model) + inference.save(dst_dir) log.info("Inference results have been saved to '%s'" % dst_dir) @@ -162,14 +204,14 @@ def build_info_parser(parser_ctor=argparse.ArgumentParser): return parser 
+@scoped def info_command(args): - project = load_project(args.project_dir) + project = scope_add(load_project(args.project_dir)) if args.name: - model = project.get_model(args.name) - print(model) + print(project.models[args.name]) else: - for name, conf in project.config.models.items(): + for name, conf in project.models.items(): print(name) if args.verbose: print(dict(conf)) diff --git a/datumaro/cli/contexts/project/__init__.py b/datumaro/cli/contexts/project/__init__.py index 18d68c5a7a..cf42367675 100644 --- a/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/cli/contexts/project/__init__.py @@ -8,225 +8,26 @@ import logging as log import os import os.path as osp -import shutil import numpy as np -from datumaro.components.annotation import AnnotationType from datumaro.components.dataset_filter import DatasetItemEncoder +from datumaro.components.environment import Environment +from datumaro.components.errors import MigrationError, ProjectNotFoundError from datumaro.components.operations import ( - DistanceComparator, ExactComparator, compute_ann_statistics, - compute_image_statistics, + compute_ann_statistics, compute_image_statistics, ) -from datumaro.components.project import PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG -from datumaro.components.project import Environment, Project +from datumaro.components.project import Project, ProjectBuildTargets from datumaro.components.validator import TaskType +from datumaro.util import str_to_bool from datumaro.util.os_util import make_file_name -from datumaro.util.scope import on_error_do, scoped +from datumaro.util.scope import scope_add, scoped -from ...util import CliException, MultilineFormatter, add_subparser -from ...util.project import generate_next_file_name, load_project -from .diff import DiffVisualizer - - -def build_create_parser(parser_ctor=argparse.ArgumentParser): - parser = parser_ctor(help="Create empty project", - description=""" - Create a new empty project.|n - |n - Examples:|n - - Create a project in the current directory:|n - |s|screate -n myproject|n - |n - - Create a project in other directory:|n - |s|screate -o path/I/like/ - """, - formatter_class=MultilineFormatter) - - parser.add_argument('-o', '--output-dir', default='.', dest='dst_dir', - help="Save directory for the new project (default: current dir") - parser.add_argument('-n', '--name', default=None, - help="Name of the new project (default: same as project dir)") - parser.add_argument('--overwrite', action='store_true', - help="Overwrite existing files in the save directory") - parser.set_defaults(command=create_command) - - return parser - -def create_command(args): - project_dir = osp.abspath(args.dst_dir) - - project_env_dir = osp.join(project_dir, DEFAULT_CONFIG.env_dir) - if osp.isdir(project_env_dir) and os.listdir(project_env_dir): - if not args.overwrite: - raise CliException("Directory '%s' already exists " - "(pass --overwrite to overwrite)" % project_env_dir) - else: - shutil.rmtree(project_env_dir, ignore_errors=True) - - own_dataset_dir = osp.join(project_dir, DEFAULT_CONFIG.dataset_dir) - if osp.isdir(own_dataset_dir) and os.listdir(own_dataset_dir): - if not args.overwrite: - raise CliException("Directory '%s' already exists " - "(pass --overwrite to overwrite)" % own_dataset_dir) - else: - # NOTE: remove the dir to avoid using data from previous project - shutil.rmtree(own_dataset_dir) - - project_name = args.name - if project_name is None: - project_name = osp.basename(project_dir) - - log.info("Creating project at '%s'" % project_dir) - 
- Project.generate(project_dir, { - 'project_name': project_name, - }) - - log.info("Project has been created at '%s'" % project_dir) - - return 0 - -def build_import_parser(parser_ctor=argparse.ArgumentParser): - builtins = sorted(Environment().importers.items) - - parser = parser_ctor(help="Create project from an existing dataset", - description=""" - Creates a project from an existing dataset. The source can be:|n - - a dataset in a supported format (check 'formats' section below)|n - - a Datumaro project|n - |n - Formats:|n - Datasets come in a wide variety of formats. Each dataset - format defines its own data structure and rules on how to - interpret the data. For example, the following data structure - is used in COCO format:|n - /dataset/|n - - /images/.jpg|n - - /annotations/|n - |n - In Datumaro dataset formats are supported by - Extractor-s and Importer-s. - An Extractor produces a list of dataset items corresponding - to the dataset. An Importer creates a project from the - data source location. - It is possible to add a custom Extractor and Importer. - To do this, you need to put an Extractor and - Importer implementation scripts to - /.datumaro/extractors - and /.datumaro/importers.|n - |n - List of builtin dataset formats: %s|n - |n - Examples:|n - - Create a project from VOC dataset in the current directory:|n - |s|simport -f voc -i path/to/voc|n - |n - - Create a project from COCO dataset in other directory:|n - |s|simport -f coco -i path/to/coco -o path/I/like/ - """ % ', '.join(builtins), - formatter_class=MultilineFormatter) - - parser.add_argument('-o', '--output-dir', default='.', dest='dst_dir', - help="Directory to save the new project to (default: current dir)") - parser.add_argument('-n', '--name', default=None, - help="Name of the new project (default: same as project dir)") - parser.add_argument('--copy', action='store_true', - help="Copy the dataset instead of saving source links") - parser.add_argument('--skip-check', action='store_true', - help="Skip source checking") - parser.add_argument('--overwrite', action='store_true', - help="Overwrite existing files in the save directory") - parser.add_argument('-i', '--input-path', required=True, dest='source', - help="Path to import project from") - parser.add_argument('-f', '--format', - help="Source project format. 
Will try to detect, if not specified.") - parser.add_argument('extra_args', nargs=argparse.REMAINDER, - help="Additional arguments for importer (pass '-- -h' for help)") - parser.set_defaults(command=import_command) - - return parser - -def import_command(args): - project_dir = osp.abspath(args.dst_dir) - - project_env_dir = osp.join(project_dir, DEFAULT_CONFIG.env_dir) - if osp.isdir(project_env_dir) and os.listdir(project_env_dir): - if not args.overwrite: - raise CliException("Directory '%s' already exists " - "(pass --overwrite to overwrite)" % project_env_dir) - else: - shutil.rmtree(project_env_dir, ignore_errors=True) - - own_dataset_dir = osp.join(project_dir, DEFAULT_CONFIG.dataset_dir) - if osp.isdir(own_dataset_dir) and os.listdir(own_dataset_dir): - if not args.overwrite: - raise CliException("Directory '%s' already exists " - "(pass --overwrite to overwrite)" % own_dataset_dir) - else: - # NOTE: remove the dir to avoid using data from previous project - shutil.rmtree(own_dataset_dir) - - project_name = args.name - if project_name is None: - project_name = osp.basename(project_dir) - - env = Environment() - log.info("Importing project from '%s'" % args.source) - - extra_args = {} - fmt = args.format - if not args.format: - if args.extra_args: - raise CliException("Extra args can not be used without format") - - log.info("Trying to detect dataset format...") - - matches = env.detect_dataset(args.source) - if len(matches) == 0: - log.error("Failed to detect dataset format. " - "Try to specify format with '-f/--format' parameter.") - return 1 - elif len(matches) != 1: - log.error("Multiple formats match the dataset: %s. " - "Try to specify format with '-f/--format' parameter.", - ', '.join(matches)) - return 1 - - fmt = matches[0] - elif args.extra_args: - if fmt in env.importers: - arg_parser = env.importers[fmt] - elif fmt in env.extractors: - arg_parser = env.extractors[fmt] - else: - raise CliException("Unknown format '%s'. A format can be added" - "by providing an Extractor and Importer plugins" % fmt) - - if hasattr(arg_parser, 'parse_cmdline'): - extra_args = arg_parser.parse_cmdline(args.extra_args) - else: - raise CliException("Format '%s' does not accept " - "extra parameters" % fmt) - - log.info("Importing project as '%s'" % fmt) - - project = Project.import_from(osp.abspath(args.source), fmt, **extra_args) - project.config.project_name = project_name - project.config.project_dir = project_dir - - if not args.skip_check or args.copy: - log.info("Checking the dataset...") - dataset = project.make_dataset() - if args.copy: - log.info("Cloning data...") - dataset.save(merge=True, save_images=True) - else: - project.save() - - log.info("Project has been created at '%s'" % project_dir) - - return 0 +from ...util import MultilineFormatter, add_subparser +from ...util.errors import CliException +from ...util.project import ( + generate_next_file_name, load_project, parse_full_revpath, +) class FilterModes(Enum): @@ -269,43 +70,57 @@ def list_options(cls): return [m.name.replace('_', '+') for m in cls] def build_export_parser(parser_ctor=argparse.ArgumentParser): - builtins = sorted(Environment().converters.items) + builtins = sorted(Environment().converters) parser = parser_ctor(help="Export project", description=""" - Exports the project dataset in some format. Optionally, a filter - can be passed, check 'filter' command description for more info. - Each dataset format has its own options, which - are passed after '--' separator (see examples), pass '-- -h' - for more info. 
If not stated otherwise, by default - only annotations are exported, to include images pass - '--save-images' parameter.|n - |n - Formats:|n - In Datumaro dataset formats are supported by Converter-s. - A Converter produces a dataset of a specific format - from dataset items. It is possible to add a custom Converter. - To do this, you need to put a Converter - definition script to /.datumaro/converters.|n - |n - List of builtin dataset formats: %s|n - |n - Examples:|n - - Export project as a VOC-like dataset, include images:|n - |s|sexport -f voc -- --save-images|n - |n - - Export project as a COCO-like dataset in other directory:|n - |s|sexport -f coco -o path/I/like/ - """ % ', '.join(builtins), + Exports a project in some format.|n + |n + Each dataset format has its own export + options, which are passed after the '--' separator (see examples), + pass '-- -h' for more info. If not stated otherwise, by default + only annotations are exported, to include images pass + '--save-images' parameter.|n + |n + A filter can be passed, check the 'filter' command description for + more info.|n + |n + Formats:|n + Datasets come in a wide variety of formats. Each dataset + format defines its own data structure and rules on how to + interpret the data. Check the user manual for the list of + supported formats, examples and documentation. + |n + The list of supported formats can be extended by plugins. + Check the "plugins" section of the developer guide for information + about plugin implementation.|n + |n + List of builtin dataset formats: {}|n + |n + The command can only be applied to a project build target, a stage + or the combined 'project' target, in which case all the targets will + be affected. + |n + Examples:|n + - Export project as a VOC-like dataset, include images:|n + |s|s%(prog)s -f voc -- --save-images|n + |n + - Export project as a COCO-like dataset in other directory:|n + |s|s%(prog)s -f coco -o path/I/like/ + """.format(', '.join(builtins)), formatter_class=MultilineFormatter) - parser.add_argument('-e', '--filter', default=None, - help="Filter expression for dataset items") + parser.add_argument('_positionals', nargs=argparse.REMAINDER, + help=argparse.SUPPRESS) # workaround for -- eaten by positionals + parser.add_argument('target', nargs='?', default='project', + help="A project target to be exported (default: %(default)s)") + parser.add_argument('-e', '--filter', + help="XML XPath filter expression for dataset items") parser.add_argument('--filter-mode', default=FilterModes.i.name, type=FilterModes.parse, help="Filter mode (options: %s; default: %s)" % \ (', '.join(FilterModes.list_options()) , '%(default)s')) - parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, + parser.add_argument('-o', '--output-dir', dest='dst_dir', help="Directory to save output (default: a subdir in the current one)") parser.add_argument('--overwrite', action='store_true', help="Overwrite existing files in the save directory") @@ -313,87 +128,126 @@ def build_export_parser(parser_ctor=argparse.ArgumentParser): help="Directory of the project to operate on (default: current dir)") parser.add_argument('-f', '--format', required=True, help="Output format") - parser.add_argument('extra_args', nargs=argparse.REMAINDER, default=None, - help="Additional arguments for converter (pass '-- -h' for help)") + parser.add_argument('extra_args', nargs=argparse.REMAINDER, + help="Additional arguments for converter (pass '-- -h' for help). 
" + "Must be specified after the main command arguments and after " + "the '--' separator") parser.set_defaults(command=export_command) return parser +@scoped def export_command(args): - project = load_project(args.project_dir) + has_sep = '--' in args._positionals + if has_sep: + pos = args._positionals.index('--') + if 1 < pos: + raise argparse.ArgumentError(None, + message="Expected no more than 1 target argument") + else: + pos = 1 + args.target = (args._positionals[:pos] or \ + [ProjectBuildTargets.MAIN_TARGET])[0] + args.extra_args = args._positionals[pos + has_sep:] - dst_dir = args.dst_dir - if dst_dir: - if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): - raise CliException("Directory '%s' already exists " - "(pass --overwrite to overwrite)" % dst_dir) + show_plugin_help = '-h' in args.extra_args or '--help' in args.extra_args + + project = None + try: + project = scope_add(load_project(args.project_dir)) + except ProjectNotFoundError: + if not show_plugin_help and args.project_dir: + raise + + if project is not None: + env = project.env else: - dst_dir = generate_next_file_name('%s-%s' % \ - (project.config.project_name, make_file_name(args.format))) - dst_dir = osp.abspath(dst_dir) + env = Environment() try: - converter = project.env.converters[args.format] + converter = env.converters[args.format] except KeyError: raise CliException("Converter for format '%s' is not found" % \ args.format) + extra_args = converter.parse_cmdline(args.extra_args) - filter_args = FilterModes.make_filter_args(args.filter_mode) + dst_dir = args.dst_dir + if dst_dir: + if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): + raise CliException("Directory '%s' already exists " + "(pass --overwrite to overwrite)" % dst_dir) + else: + dst_dir = generate_next_file_name('export-%s' % \ + make_file_name(args.format)) + dst_dir = osp.abspath(dst_dir) - log.info("Loading the project...") - dataset = project.make_dataset() + if args.filter: + filter_args = FilterModes.make_filter_args(args.filter_mode) + filter_expr = args.filter - log.info("Exporting the project...") + log.info("Loading the project...") + target = args.target if args.filter: - dataset = dataset.filter(args.filter, **filter_args) - converter = project.env.converters[args.format] - converter.convert(dataset, save_dir=dst_dir, **extra_args) + target = project.working_tree.build_targets.add_filter_stage( + target, expr=filter_expr, params=filter_args) + + log.info("Exporting...") + + dataset = project.working_tree.make_dataset(target) + dataset.export(save_dir=dst_dir, format=converter, **extra_args) - log.info("Project exported to '%s' as '%s'" % (dst_dir, args.format)) + log.info("Results have been saved to '%s'" % dst_dir) return 0 def build_filter_parser(parser_ctor=argparse.ArgumentParser): - parser = parser_ctor(help="Extract subproject", + parser = parser_ctor(help="Extract subdataset", description=""" - Extracts a subproject that contains only items matching filter. - A filter is an XPath expression, which is applied to XML - representation of a dataset item. Check '--dry-run' parameter - to see XML representations of the dataset items.|n - |n - To filter annotations use the mode ('-m') parameter.|n - Supported modes:|n - - 'i', 'items'|n - - 'a', 'annotations'|n - - 'i+a', 'a+i', 'items+annotations', 'annotations+items'|n - When filtering annotations, use the 'items+annotations' - mode to point that annotation-less dataset items should be - removed. 
To select an annotation, write an XPath that - returns 'annotation' elements (see examples).|n - |n - Examples:|n - - Filter images with width < height:|n - |s|sextract -e '/item[image/width < image/height]'|n - |n - - Filter images with large-area bboxes:|n - |s|sextract -e '/item[annotation/type="bbox" and - annotation/area>2000]'|n - |n - - Filter out all irrelevant annotations from items:|n - |s|sextract -m a -e '/item/annotation[label = "person"]'|n - |n - - Filter out all irrelevant annotations from items:|n - |s|sextract -m a -e '/item/annotation[label="cat" and - area > 99.5]'|n - |n - - Filter occluded annotations and items, if no annotations left:|n - |s|sextract -m i+a -e '/item/annotation[occluded="True"]' + Extracts a subdataset that contains only items matching filter. + A filter is an XPath expression, which is applied to XML + representation of a dataset item. Check '--dry-run' parameter + to see XML representations of the dataset items.|n + |n + To filter annotations use the mode ('-m') parameter.|n + Supported modes:|n + - 'i', 'items'|n + - 'a', 'annotations'|n + - 'i+a', 'a+i', 'items+annotations', 'annotations+items'|n + When filtering annotations, use the 'items+annotations' + mode to point that annotation-less dataset items should be + removed. To select an annotation, write an XPath that + returns 'annotation' elements (see examples).|n + |n + The command can only be applied to a project build target, a stage + or the combined 'project' target, in which case all the targets will + be affected. A build tree stage will be added if '--stage' is enabled, + and the resulting dataset(-s) will be saved if '--apply' is enabled. + |n + Examples:|n + - Filter images with width < height:|n + |s|s%(prog)s -e '/item[image/width < image/height]'|n + |n + - Filter images with large-area bboxes:|n + |s|s%(prog)s -e '/item[annotation/type="bbox" and + annotation/area>2000]'|n + |n + - Filter out all irrelevant annotations from items:|n + |s|s%(prog)s -m a -e '/item/annotation[label = "person"]'|n + |n + - Filter out all irrelevant annotations from items:|n + |s|s%(prog)s -m a -e '/item/annotation[label="cat" and + area > 99.5]'|n + |n + - Filter occluded annotations and items, if no annotations left:|n + |s|s%(prog)s -m i+a -e '/item/annotation[occluded="True"]' """, formatter_class=MultilineFormatter) - parser.add_argument('-e', '--filter', default=None, + parser.add_argument('target', nargs='?', default='project', + help="A project target to apply transform to (default: %(default)s)") + parser.add_argument('-e', '--filter', help="XML XPath filter expression for dataset items") parser.add_argument('-m', '--mode', default=FilterModes.i.name, type=FilterModes.parse, @@ -401,8 +255,24 @@ def build_filter_parser(parser_ctor=argparse.ArgumentParser): (', '.join(FilterModes.list_options()) , '%(default)s')) parser.add_argument('--dry-run', action='store_true', help="Print XML representations to be filtered and exit") - parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, - help="Output directory (default: update current project)") + parser.add_argument('-o', '--output-dir', dest='dst_dir', + help=""" + Output directory. Can be omitted for data source targets + (i.e. not intermediate stages) and the 'project' target, + in which case the results will be saved inplace in the + working tree. + """) + parser.add_argument('--stage', type=str_to_bool, default=True, + help=""" + Include this action as a project build step. 
+ If true, this operation will be saved in the project + build tree, allowing to reproduce the resulting dataset later. + Applicable only to data source targets (i.e. not intermediate + stages) and the 'project' target (default: %(default)s) + """) + parser.add_argument('--apply', type=str_to_bool, default=True, + help="Run this command immediately. If disabled, only the " + "build tree stage will be written (default: %(default)s)") parser.add_argument('--overwrite', action='store_true', help="Overwrite existing files in the save directory") parser.add_argument('-p', '--project', dest='project_dir', default='.', @@ -411,26 +281,29 @@ def build_filter_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def filter_command(args): - project = load_project(args.project_dir) - - if not args.dry_run: - dst_dir = args.dst_dir - if dst_dir: - if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): - raise CliException("Directory '%s' already exists " - "(pass --overwrite to overwrite)" % dst_dir) - else: - dst_dir = generate_next_file_name('%s-filter' % \ - project.config.project_name) - dst_dir = osp.abspath(dst_dir) + project = scope_add(load_project(args.project_dir)) - dataset = project.make_dataset() + # TODO: check if we can accept a dataset revpath here + if not args.dry_run and args.stage and \ + args.target not in project.working_tree.build_targets: + raise CliException("Adding a stage is only allowed for " + "source and 'project' targets, not '%s'" % args.target) + + dst_dir = args.dst_dir + if not args.dry_run and dst_dir: + if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): + raise CliException("Directory '%s' already exists " + "(pass --overwrite to overwrite)" % dst_dir) + dst_dir = osp.abspath(dst_dir) filter_args = FilterModes.make_filter_args(args.mode) + filter_expr = args.filter if args.dry_run: - dataset = dataset.filter(expr=args.filter, **filter_args) + dataset = project.working_tree.make_dataset(args.target) + dataset = dataset.filter(expr=filter_expr, **filter_args) for item in dataset: encoded_item = DatasetItemEncoder.encode(item, dataset.categories()) xml_item = DatasetItemEncoder.to_string(encoded_item) @@ -440,278 +313,233 @@ def filter_command(args): if not args.filter: raise CliException("Expected a filter expression ('-e' argument)") - dataset.filter_project(save_dir=dst_dir, - filter_expr=args.filter, **filter_args) - - log.info("Subproject has been extracted to '%s'" % dst_dir) - - return 0 + if args.target == ProjectBuildTargets.MAIN_TARGET: + targets = list(project.working_tree.sources) + else: + targets = [args.target] -def build_merge_parser(parser_ctor=argparse.ArgumentParser): - parser = parser_ctor(help="Merge two projects", - description=""" - Updates items of the current project with items - from other project.|n - |n - Examples:|n - - Update a project with items from other project:|n - |s|smerge -p path/to/first/project path/to/other/project - """, - formatter_class=MultilineFormatter) + for target in targets: + project.working_tree.build_targets.add_filter_stage(target, + expr=filter_expr, params=filter_args) - parser.add_argument('other_project_dir', - help="Path to a project") - parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, - help="Output directory (default: current project's dir)") - parser.add_argument('--overwrite', action='store_true', - help="Overwrite existing files in the save directory") - parser.add_argument('-p', '--project', dest='project_dir', default='.', - 
help="Directory of the project to operate on (default: current dir)") - parser.set_defaults(command=merge_command) - - return parser + if args.apply: + log.info("Filtering...") -def merge_command(args): - first_project = load_project(args.project_dir) - second_project = load_project(args.other_project_dir) + if args.dst_dir: + dataset = project.working_tree.make_dataset(args.target) + dataset.save(dst_dir, save_images=True) - dst_dir = args.dst_dir - if dst_dir: - if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): - raise CliException("Directory '%s' already exists " - "(pass --overwrite to overwrite)" % dst_dir) + log.info("Results have been saved to '%s'" % dst_dir) + else: + for target in targets: + dataset = project.working_tree.make_dataset(target) - first_dataset = first_project.make_dataset() - second_dataset = second_project.make_dataset() + # Source might be missing in the working dir, so we specify + # the output directory. + # We specify save_images here as a heuristic. It can probably + # be improved by checking if there are images in the dataset + # directory. + dataset.save(project.source_data_dir(target), save_images=True) - first_dataset.update(second_dataset) - first_dataset.save(save_dir=dst_dir) + log.info("Finished") - if dst_dir is None: - dst_dir = first_project.config.project_dir - dst_dir = osp.abspath(dst_dir) - log.info("Merge results have been saved to '%s'" % dst_dir) + if args.stage: + for target_name in targets: + project.refresh_source_hash(target_name) + project.working_tree.save() return 0 -def build_diff_parser(parser_ctor=argparse.ArgumentParser): - parser = parser_ctor(help="Compare projects", +def build_transform_parser(parser_ctor=argparse.ArgumentParser): + builtins = sorted(Environment().transforms) + + parser = parser_ctor(help="Transform project", description=""" - Compares two projects, match annotations by distance.|n + Applies a batch operation to dataset and produces a new dataset.|n + |n + Builtin transforms: {}|n + |n + The command can only be applied to a project build target, a stage + or the combined 'project' target, in which case all the targets will + be affected. A build tree stage will be added if '--stage' is enabled, + and the resulting dataset(-s) will be saved if '--apply' is enabled. 
|n Examples:|n - - Compare two projects, match boxes if IoU > 0.7,|n - |s|s|s|sprint results to Tensorboard: - |s|sdiff path/to/other/project -o diff/ -v tensorboard --iou-thresh 0.7 - """, + - Convert instance polygons to masks:|n + |s|s%(prog)s -t polygons_to_masks|n + - Rename dataset items by a regular expression|n + |s|s- Replace 'pattern' with 'replacement'|n|n + |s|s%(prog)s -t rename -- -e '|pattern|replacement|'|n + |s|s- Remove 'frame_' from item ids|n + |s|s%(prog)s -t rename -- -e '|frame_(\\d+)|\\1|' + """.format(', '.join(builtins)), formatter_class=MultilineFormatter) - parser.add_argument('other_project_dir', - help="Directory of the second project to be compared") - parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, - help="Directory to save comparison results (default: do not save)") - parser.add_argument('-v', '--visualizer', - default=DiffVisualizer.DEFAULT_FORMAT.name, - choices=[f.name for f in DiffVisualizer.OutputFormat], - help="Output format (default: %(default)s)") - parser.add_argument('--iou-thresh', default=0.5, type=float, - help="IoU match threshold for detections (default: %(default)s)") - parser.add_argument('--conf-thresh', default=0.5, type=float, - help="Confidence threshold for detections (default: %(default)s)") + parser.add_argument('_positionals', nargs=argparse.REMAINDER, + help=argparse.SUPPRESS) # workaround for -- eaten by positionals + parser.add_argument('target', nargs='?', default='project', + help="Project target to apply transform to (default: all)") + parser.add_argument('-t', '--transform', required=True, + help="Transform to apply to the project") + parser.add_argument('-o', '--output-dir', dest='dst_dir', + help=""" + Output directory. Can be omitted for data source targets + (i.e. not intermediate stages) and the 'project' target, + in which case the results will be saved inplace in the + working tree. + """) parser.add_argument('--overwrite', action='store_true', help="Overwrite existing files in the save directory") parser.add_argument('-p', '--project', dest='project_dir', default='.', - help="Directory of the first project to be compared (default: current dir)") - parser.set_defaults(command=diff_command) + help="Directory of the project to operate on (default: current dir)") + parser.add_argument('--stage', type=str_to_bool, default=True, + help=""" + Include this action as a project build step. + If true, this operation will be saved in the project + build tree, allowing to reproduce the resulting dataset later. + Applicable only to data source targets (i.e. not intermediate + stages) and the 'project' target (default: %(default)s) + """) + parser.add_argument('--apply', type=str_to_bool, default=True, + help="Run this command immediately. If disabled, only the " + "build tree stage will be written (default: %(default)s)") + parser.add_argument('extra_args', nargs=argparse.REMAINDER, + help="Additional arguments for transformation (pass '-- -h' for help). 
" + "Must be specified after the main command arguments and after " + "the '--' separator") + parser.set_defaults(command=transform_command) return parser @scoped -def diff_command(args): - first_project = load_project(args.project_dir) - second_project = load_project(args.other_project_dir) - - comparator = DistanceComparator(iou_threshold=args.iou_thresh) - - dst_dir = args.dst_dir - if dst_dir: - if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): - raise CliException("Directory '%s' already exists " - "(pass --overwrite to overwrite)" % dst_dir) +def transform_command(args): + has_sep = '--' in args._positionals + if has_sep: + pos = args._positionals.index('--') + if 1 < pos: + raise argparse.ArgumentError(None, + message="Expected no more than 1 target argument") else: - dst_dir = generate_next_file_name('%s-%s-diff' % ( - first_project.config.project_name, - second_project.config.project_name) - ) - dst_dir = osp.abspath(dst_dir) - log.info("Saving diff to '%s'" % dst_dir) - - if not osp.exists(dst_dir): - on_error_do(shutil.rmtree, dst_dir, ignore_errors=True) + pos = 1 + args.target = (args._positionals[:pos] or \ + [ProjectBuildTargets.MAIN_TARGET])[0] + args.extra_args = args._positionals[pos + has_sep:] - with DiffVisualizer(save_dir=dst_dir, comparator=comparator, - output_format=args.visualizer) as visualizer: - visualizer.save( - first_project.make_dataset(), - second_project.make_dataset()) + show_plugin_help = '-h' in args.extra_args or '--help' in args.extra_args - return 0 - -_ediff_default_if = ['id', 'group'] # avoid https://bugs.python.org/issue16399 - -def build_ediff_parser(parser_ctor=argparse.ArgumentParser): - parser = parser_ctor(help="Compare projects for equality", - description=""" - Compares two projects for equality.|n - |n - Examples:|n - - Compare two projects, exclude annotation group |n - |s|s|sand the 'is_crowd' attribute from comparison:|n - |s|sediff other/project/ -if group -ia is_crowd - """, - formatter_class=MultilineFormatter) - - parser.add_argument('other_project_dir', - help="Directory of the second project to be compared") - parser.add_argument('-iia', '--ignore-item-attr', action='append', - help="Ignore item attribute (repeatable)") - parser.add_argument('-ia', '--ignore-attr', action='append', - help="Ignore annotation attribute (repeatable)") - parser.add_argument('-if', '--ignore-field', action='append', - help="Ignore annotation field (repeatable, default: %s)" % \ - _ediff_default_if) - parser.add_argument('--match-images', action='store_true', - help='Match dataset items by images instead of ids') - parser.add_argument('--all', action='store_true', - help="Include matches in the output") - parser.add_argument('-p', '--project', dest='project_dir', default='.', - help="Directory of the first project to be compared (default: current dir)") - parser.set_defaults(command=ediff_command) - - return parser - -def ediff_command(args): - first_project = load_project(args.project_dir) - second_project = load_project(args.other_project_dir) - - if args.ignore_field: - args.ignore_field = _ediff_default_if - comparator = ExactComparator( - match_images=args.match_images, - ignored_fields=args.ignore_field, - ignored_attrs=args.ignore_attr, - ignored_item_attrs=args.ignore_item_attr) - matches, mismatches, a_extra, b_extra, errors = \ - comparator.compare_datasets( - first_project.make_dataset(), second_project.make_dataset()) - output = { - "mismatches": mismatches, - "a_extra_items": sorted(a_extra), - "b_extra_items": 
sorted(b_extra), - "errors": errors, - } - if args.all: - output["matches"] = matches - - output_file = generate_next_file_name('diff', ext='.json') - with open(output_file, 'w', encoding='utf-8') as f: - json.dump(output, f, indent=4, sort_keys=True) - - print("Found:") - print("The first project has %s unmatched items" % len(a_extra)) - print("The second project has %s unmatched items" % len(b_extra)) - print("%s item conflicts" % len(errors)) - print("%s matching annotations" % len(matches)) - print("%s mismatching annotations" % len(mismatches)) - - log.info("Output has been saved to '%s'" % output_file) - - return 0 - -def build_transform_parser(parser_ctor=argparse.ArgumentParser): - builtins = sorted(Environment().transforms.items) + project = None + try: + project = scope_add(load_project(args.project_dir)) + except ProjectNotFoundError: + if not show_plugin_help and args.project_dir: + raise - parser = parser_ctor(help="Transform project", - description=""" - Applies some operation to dataset items in the project - and produces a new project.|n - |n - Builtin transforms: %s|n - |n - Examples:|n - - Convert instance polygons to masks:|n - |s|stransform -t polygons_to_masks - """ % ', '.join(builtins), - formatter_class=MultilineFormatter) + if project is not None: + env = project.env + else: + env = Environment() - parser.add_argument('-t', '--transform', required=True, - help="Transform to apply to the project") - parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, - help="Directory to save output (default: current dir)") - parser.add_argument('--overwrite', action='store_true', - help="Overwrite existing files in the save directory") - parser.add_argument('-p', '--project', dest='project_dir', default='.', - help="Directory of the project to operate on (default: current dir)") - parser.add_argument('extra_args', nargs=argparse.REMAINDER, default=None, - help="Additional arguments for transformation (pass '-- -h' for help)") - parser.set_defaults(command=transform_command) + try: + transform = env.transforms[args.transform] + except KeyError: + raise CliException("Transform '%s' is not found" % args.transform) - return parser + extra_args = transform.parse_cmdline(args.extra_args) -def transform_command(args): - project = load_project(args.project_dir) + # TODO: check if we can accept a dataset revpath here + if args.stage and args.target not in project.working_tree.build_targets: + raise CliException("Adding a stage is only allowed for " + "source and 'project' targets, not '%s'" % args.target) dst_dir = args.dst_dir if dst_dir: if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): raise CliException("Directory '%s' already exists " "(pass --overwrite to overwrite)" % dst_dir) + dst_dir = osp.abspath(dst_dir) + + if args.target == ProjectBuildTargets.MAIN_TARGET: + targets = list(project.working_tree.sources) else: - dst_dir = generate_next_file_name('%s-%s' % \ - (project.config.project_name, make_file_name(args.transform))) - dst_dir = osp.abspath(dst_dir) + targets = [args.target] - try: - transform = project.env.transforms[args.transform] - except KeyError: - raise CliException("Transform '%s' is not found" % args.transform) + for target in targets: + project.working_tree.build_targets.add_transform_stage(target, + args.transform, params=extra_args) - extra_args = {} - if hasattr(transform, 'parse_cmdline'): - extra_args = transform.parse_cmdline(args.extra_args) + if args.apply: + log.info("Transforming...") - log.info("Loading the 
project...") - dataset = project.make_dataset() + if args.dst_dir: + dataset = project.working_tree.make_dataset(args.target) + dataset.save(dst_dir, save_images=True) + + log.info("Results have been saved to '%s'" % dst_dir) + else: + for target in targets: + dataset = project.working_tree.make_dataset(target) + + # Source might be missing in the working dir, so we specify + # the output directory + # We specify save_images here as a heuristic. It can probably + # be improved by checking if there are images in the dataset + # directory. + dataset.save(project.source_data_dir(target), save_images=True) - log.info("Transforming the project...") - dataset.transform_project( - method=transform, - save_dir=dst_dir, - **extra_args - ) + log.info("Finished") - log.info("Transform results have been saved to '%s'" % dst_dir) + if args.stage: + for target_name in targets: + project.refresh_source_hash(target_name) + project.working_tree.save() return 0 def build_stats_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor(help="Get project statistics", description=""" - Outputs various project statistics like image mean and std, - annotations count etc. + Outputs various project statistics like image mean and std, + annotations count etc.|n + |n + Target dataset is specified by a revpath. The full syntax is:|n + - Dataset paths:|n + |s|s- [ : ]|n + - Revision paths:|n + |s|s- [ @ ] [ : ]|n + |s|s- [ : ]|n + |s|s- |n + |n + Both forms use the -p/--project as a context for plugins. It can be + useful for dataset paths in targets. When not specified, the current + project's working tree is used.|n + |n + Examples:|n + - Compute project statistics:|n + |s|s%(prog)s """, formatter_class=MultilineFormatter) + parser.add_argument('target', default='project', nargs='?', + help="Target dataset revpath (default: project)") parser.add_argument('-p', '--project', dest='project_dir', default='.', help="Directory of the project to operate on (default: current dir)") parser.set_defaults(command=stats_command) return parser +@scoped def stats_command(args): - project = load_project(args.project_dir) + project = None + try: + project = scope_add(load_project(args.project_dir)) + except ProjectNotFoundError: + if args.project_dir: + raise + + dataset, target_project = parse_full_revpath(args.target, project) + if target_project: + scope_add(target_project) - dataset = project.make_dataset() stats = {} stats.update(compute_image_statistics(dataset)) stats.update(compute_ann_statistics(dataset)) @@ -724,123 +552,157 @@ def stats_command(args): def build_info_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor(help="Get project info", description=""" - Outputs project info. 
+ Outputs project info - information about plugins, + sources, build tree, models and revisions.|n + |n + Examples:|n + - Print project info for the current working tree:|n + |s|s%(prog)s|n + |n + - Print project info for the previous revision:|n + |s|s%(prog)s HEAD~1 """, formatter_class=MultilineFormatter) - parser.add_argument('--all', action='store_true', - help="Print all information") + parser.add_argument('revision', default='', nargs='?', + help="Target revision (default: current working tree)") parser.add_argument('-p', '--project', dest='project_dir', default='.', help="Directory of the project to operate on (default: current dir)") parser.set_defaults(command=info_command) return parser +@scoped def info_command(args): - project = load_project(args.project_dir) - config = project.config - env = project.env - dataset = project.make_dataset() + project = scope_add(load_project(args.project_dir)) + rev = project.get_rev(args.revision) + env = rev.env print("Project:") - print(" name:", config.project_name) - print(" location:", config.project_dir) + print(" location:", project._root_dir) print("Plugins:") - print(" importers:", ', '.join(env.importers.items)) - print(" extractors:", ', '.join(env.extractors.items)) - print(" converters:", ', '.join(env.converters.items)) - print(" launchers:", ', '.join(env.launchers.items)) - - print("Sources:") - for source_name, source in config.sources.items(): - print(" source '%s':" % source_name) - print(" format:", source.format) - print(" url:", source.url) - print(" location:", project.local_source_dir(source_name)) - - def print_extractor_info(extractor, indent=''): - print("%slength:" % indent, len(extractor)) - - categories = extractor.categories() - print("%scategories:" % indent, ', '.join(c.name for c in categories)) - - for cat_type, cat in categories.items(): - print("%s %s:" % (indent, cat_type.name)) - if cat_type == AnnotationType.label: - print("%s count:" % indent, len(cat.items)) - - count_threshold = 10 - if args.all: - count_threshold = len(cat.items) - labels = ', '.join(c.name for c in cat.items[:count_threshold]) - if count_threshold < len(cat.items): - labels += " (and %s more)" % ( - len(cat.items) - count_threshold) - print("%s labels:" % indent, labels) - - print("Dataset:") - print_extractor_info(dataset, indent=" ") - - subsets = dataset.subsets() - print(" subsets:", ', '.join(subsets)) - for subset_name in subsets: - subset = dataset.get_subset(subset_name) - print(" subset '%s':" % subset_name) - print_extractor_info(subset, indent=" ") + print(" extractors:", ', '.join( + sorted(set(env.extractors) | set(env.importers)))) + print(" converters:", ', '.join(env.converters)) + print(" launchers:", ', '.join(env.launchers)) print("Models:") - for model_name, model in config.models.items(): + for model_name, model in project.models.items(): print(" model '%s':" % model_name) print(" type:", model.launcher) + print("Sources:") + for source_name, source in rev.sources.items(): + print(" '%s':" % source_name) + print(" format:", source.format) + print(" url:", source.url) + print(" location:", + osp.join(project.source_data_dir(source_name), source.path)) + print(" options:", source.options) + print(" hash:", source.hash) + + print(" stages:") + for stage in rev.build_targets[source_name].stages: + print(" '%s':" % stage.name) + print(" type:", stage.type) + print(" hash:", stage.hash) + if stage.kind: + print(" kind:", stage.kind) + if stage.params: + print(" parameters:", stage.params) + return 0 def 
build_validate_parser(parser_ctor=argparse.ArgumentParser): - def _parse_task_type(s): - try: - return TaskType[s.lower()].name - except: - raise argparse.ArgumentTypeError("Unknown task type %s. Expected " - "one of: %s" % (s, ', '.join(t.name for t in TaskType))) - - parser = parser_ctor(help="Validate project", description=""" - Validates project based on specified task type and stores - results like statistics, reports and summary in JSON file. + Validates a dataset according to the task type and + reports summary in a JSON file.|n + Target dataset is specified by a revpath. The full syntax is:|n + - Dataset paths:|n + |s|s- [ : ]|n + - Revision paths:|n + |s|s- [ @ ] [ : ]|n + |s|s- [ : ]|n + |s|s- |n + |n + Both forms use the -p/--project as a context for plugins. It can be + useful for dataset paths in targets. When not specified, the current + project's working tree is used.|n + |n + Examples:|n + - Validate a project's subset as a classification dataset:|n + |s|s%(prog)s -t classification -s train """, formatter_class=MultilineFormatter) - parser.add_argument('-t', '--task_type', type=_parse_task_type, - help="Task type for validation, one of %s" % \ - ', '.join(t.name for t in TaskType)) - parser.add_argument('-s', '--subset', dest='subset_name', default=None, + task_types = ', '.join(t.name for t in TaskType) + def _parse_task_type(s): + try: + return TaskType[s.lower()].name + except: + raise argparse.ArgumentTypeError("Unknown task type %s. Expected " + "one of: %s" % (s, task_types)) + + parser.add_argument('_positionals', nargs=argparse.REMAINDER, + help=argparse.SUPPRESS) # workaround for -- eaten by positionals + parser.add_argument('target', default='project', nargs='?', + help="Target dataset revpath (default: project)") + parser.add_argument('-t', '--task', + type=_parse_task_type, required=True, + help="Task type for validation, one of %s" % task_types) + parser.add_argument('-s', '--subset', dest='subset_name', help="Subset to validate (default: whole dataset)") parser.add_argument('-p', '--project', dest='project_dir', default='.', help="Directory of the project to validate (default: current dir)") - parser.add_argument('extra_args', nargs=argparse.REMAINDER, default=None, + parser.add_argument('extra_args', nargs=argparse.REMAINDER, help="Optional arguments for validator (pass '-- -h' for help)") parser.set_defaults(command=validate_command) return parser +@scoped def validate_command(args): - project = load_project(args.project_dir) - dst_file_name = f'report-{args.task_type}' + has_sep = '--' in args._positionals + if has_sep: + pos = args._positionals.index('--') + if 1 < pos: + raise argparse.ArgumentError(None, + message="Expected no more than 1 target argument") + else: + pos = 1 + args.target = (args._positionals[:pos] or ['project'])[0] + args.extra_args = args._positionals[pos + has_sep:] - dataset = project.make_dataset() - if args.subset_name is not None: - dataset = dataset.get_subset(args.subset_name) - dst_file_name += f'-{args.subset_name}' + show_plugin_help = '-h' in args.extra_args or '--help' in args.extra_args + + project = None + try: + project = scope_add(load_project(args.project_dir)) + except ProjectNotFoundError: + if not show_plugin_help and args.project_dir: + raise + + if project is not None: + env = project.env + else: + env = Environment() try: - validator_type = project.env.validators[args.task_type] + validator_type = env.validators[args.task] except KeyError: - raise CliException("Validator type '%s' is not found" % args.task_type) 
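The validate, export and transform commands in this patch all rely on the `_positionals` workaround for combining an optional target positional with `argparse.REMAINDER`. Below is a self-contained sketch (plain `argparse`, no Datumaro imports, illustrative argument values) of how the re-splitting on the '--' separator behaves; it assumes CPython's argparse keeps the '--' token inside a REMAINDER group, which is what the commands here rely on.

import argparse

parser = argparse.ArgumentParser()
# REMAINDER grabs the target and everything after it, including '--',
# so the command code re-splits the raw list itself.
parser.add_argument('_positionals', nargs=argparse.REMAINDER,
    help=argparse.SUPPRESS)
parser.add_argument('target', nargs='?', default='project')

args = parser.parse_args(['source-1', '--', '--save-images'])
has_sep = '--' in args._positionals
pos = args._positionals.index('--') if has_sep else 1
args.target = (args._positionals[:pos] or ['project'])[0]
args.extra_args = args._positionals[pos + has_sep:]
print(args.target, args.extra_args)  # -> source-1 ['--save-images']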
+ raise CliException("Validator type '%s' is not found" % args.task) - extra_args = {} - if hasattr(validator_type, 'parse_cmdline'): - extra_args = validator_type.parse_cmdline(args.extra_args) + extra_args = validator_type.parse_cmdline(args.extra_args) + + dataset, target_project = parse_full_revpath(args.target, project) + if target_project: + scope_add(target_project) + + dst_file_name = f'validation-report' + if args.subset_name is not None: + dataset = dataset.get_subset(args.subset_name) + dst_file_name += f'-{args.subset_name}' validator = validator_type(**extra_args) report = validator.validate(dataset) @@ -866,6 +728,49 @@ def _make_serializable(d): json.dump(report, f, indent=4, sort_keys=True, default=numpy_encoder) +def build_migrate_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Migrate project", + description=""" + Migrates the project from the old version to a new one.|n + |n + Examples: + - Migrate a project from v1 to v2, save the new project in other dir:|n + |s|s%(prog)s -o + """, + formatter_class=MultilineFormatter) + + parser.add_argument('-o', '--output-dir', dest='dst_dir', required=True, + help="Output directory for the updated project") + parser.add_argument('-f', '--force', action='store_true', + help="Ignore source import errors (default: %(default)s)") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to validate (default: current dir)") + parser.add_argument('--overwrite', action='store_true', + help="Overwrite existing files in the save directory") + parser.set_defaults(command=migrate_command) + + return parser + +@scoped +def migrate_command(args): + dst_dir = args.dst_dir + if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir): + raise CliException("Directory '%s' already exists " + "(pass --overwrite to overwrite)" % dst_dir) + dst_dir = osp.abspath(dst_dir) + + log.debug("Migrating project from v1 to v2...") + + try: + Project.migrate_from_v1_to_v2(args.project_dir, dst_dir, + skip_import_errors=args.force) + except Exception as e: + raise MigrationError("Failed to migrate the project " + "automatically. 
Try to create a new project and " + "add sources manually with 'datum create' and 'datum add'.") from e + + log.debug("Finished") + def build_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor( description=""" @@ -878,16 +783,12 @@ def build_parser(parser_ctor=argparse.ArgumentParser): formatter_class=MultilineFormatter) subparsers = parser.add_subparsers() - add_subparser(subparsers, 'create', build_create_parser) - add_subparser(subparsers, 'import', build_import_parser) add_subparser(subparsers, 'export', build_export_parser) add_subparser(subparsers, 'filter', build_filter_parser) - add_subparser(subparsers, 'merge', build_merge_parser) - add_subparser(subparsers, 'diff', build_diff_parser) - add_subparser(subparsers, 'ediff', build_ediff_parser) add_subparser(subparsers, 'transform', build_transform_parser) add_subparser(subparsers, 'info', build_info_parser) add_subparser(subparsers, 'stats', build_stats_parser) add_subparser(subparsers, 'validate', build_validate_parser) + add_subparser(subparsers, 'migrate', build_migrate_parser) return parser diff --git a/datumaro/cli/contexts/source.py b/datumaro/cli/contexts/source.py index 54863340fe..d9a51ed1b7 100644 --- a/datumaro/cli/contexts/source.py +++ b/datumaro/cli/contexts/source.py @@ -5,230 +5,181 @@ import argparse import logging as log import os -import os.path as osp -import shutil +from datumaro.components.errors import ProjectNotFoundError from datumaro.components.project import Environment +from datumaro.util.scope import on_error_do, scope_add, scoped -from ..util import CliException, MultilineFormatter, add_subparser -from ..util.project import load_project +from ..util import MultilineFormatter, add_subparser, join_cli_args +from ..util.errors import CliException +from ..util.project import generate_next_name, load_project def build_add_parser(parser_ctor=argparse.ArgumentParser): - builtins = sorted(Environment().extractors.items) - - base_parser = argparse.ArgumentParser(add_help=False) - base_parser.add_argument('-n', '--name', default=None, - help="Name of the new source") - base_parser.add_argument('-f', '--format', required=True, - help="Source dataset format") - base_parser.add_argument('--skip-check', action='store_true', - help="Skip source checking") - base_parser.add_argument('-p', '--project', dest='project_dir', default='.', - help="Directory of the project to operate on (default: current dir)") + env = Environment() + builtins = sorted(set(env.extractors) | set(env.importers)) parser = parser_ctor(help="Add data source to project", description=""" - Adds a data source to a project. The source can be:|n - - a dataset in a supported format (check 'formats' section below)|n - - a Datumaro project|n - |n - The source can be either a local directory or a remote - git repository. Each source type has its own parameters, which can - be checked by:|n - '%s'.|n - |n - Formats:|n - Datasets come in a wide variety of formats. Each dataset - format defines its own data structure and rules on how to - interpret the data. For example, the following data structure - is used in COCO format:|n - /dataset/|n - - /images/.jpg|n - - /annotations/|n - |n - In Datumaro dataset formats are supported by Extractor-s. - An Extractor produces a list of dataset items corresponding - to the dataset. It is possible to add a custom Extractor. 
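For reference, a rough sketch (not part of the patch; the directory names are placeholders) of driving the same v1-to-v2 migration that the 'migrate' command above wraps:

import os.path as osp

from datumaro.components.project import Project

old_project_dir = 'path/to/v1_project'        # placeholder
new_project_dir = osp.abspath('migrated-v2')  # placeholder

# Mirrors migrate_command: skip_import_errors corresponds to '--force'
Project.migrate_from_v1_to_v2(old_project_dir, new_project_dir,
    skip_import_errors=False)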
- To do this, you need to put an Extractor - definition script to /.datumaro/extractors.|n - |n - List of builtin source formats: %s|n - |n - Examples:|n - - Add a local directory with VOC-like dataset:|n - |s|sadd path path/to/voc -f voc_detection|n - - Add a local file with CVAT annotations, call it 'mysource'|n - |s|s|s|sto the project somewhere else:|n - |s|sadd path path/to/cvat.xml -f cvat -n mysource -p somewhere/else/ - """ % ('%(prog)s SOURCE_TYPE --help', ', '.join(builtins)), - formatter_class=MultilineFormatter, - add_help=False) - parser.set_defaults(command=add_command) + Adds a data source to a project. A data source is a dataset + in a supported format (check 'formats' section below).|n + |n + Currently, only local paths to sources are supported.|n + Once added, a source is copied into project.|n + |n + Formats:|n + Datasets come in a wide variety of formats. Each dataset + format defines its own data structure and rules on how to + interpret the data. Check the user manual for the list of + supported formats, examples and documentation. + |n + The list of supported formats can be extended by plugins. + Check the "plugins" section of the developer guide for information + about plugin implementation.|n + |n + Each dataset format has its own import options, which are passed + after the '--' separator (see examples), pass '-- -h' for more info.|n + |n + Builtin formats: {}|n + |n + Examples:|n + - Add a local directory with a VOC-like dataset:|n + |s|s%(prog)s -f voc path/to/voc|n + |n + - Add a directory with a COCO dataset, use only a specific file:|n + |s|s%(prog)s -f coco_instances -r anns/train.json path/to/coco|n + |n + - Add a local file with CVAT annotations, call it 'mysource'|n + |s|s|s|sto the project in a specific place:|n + |s|s%(prog)s -f cvat -n mysource -p project/path/ path/to/cvat.xml + """.format(', '.join(builtins)), + formatter_class=MultilineFormatter) - sp = parser.add_subparsers(dest='source_type', metavar='SOURCE_TYPE', - help="The type of the data source " - "(call '%s SOURCE_TYPE --help' for more info)" % parser.prog) - - dir_parser = sp.add_parser('path', help="Add local path as source", - parents=[base_parser]) - dir_parser.add_argument('url', - help="Path to the source") - dir_parser.add_argument('--copy', action='store_true', - help="Copy the dataset instead of saving source links") - - repo_parser = sp.add_parser('git', help="Add git repository as source", - parents=[base_parser]) - repo_parser.add_argument('url', - help="URL of the source git repository") - repo_parser.add_argument('-b', '--branch', default='master', - help="Branch of the source repository (default: %(default)s)") - repo_parser.add_argument('--checkout', action='store_true', - help="Do branch checkout") - - # NOTE: add common parameters to the parent help output - # the other way could be to use parse_known_args() - display_parser = argparse.ArgumentParser( - parents=[base_parser, parser], - prog=parser.prog, usage="%(prog)s [-h] SOURCE_TYPE ...", - description=parser.description, formatter_class=MultilineFormatter) - class HelpAction(argparse._HelpAction): - def __call__(self, parser, namespace, values, option_string=None): - display_parser.print_help() - parser.exit() - - parser.add_argument('-h', '--help', action=HelpAction, - help='show this help message and exit') - - # TODO: needed distinction on how to add an extractor or a remote source + parser.add_argument('_positionals', nargs=argparse.REMAINDER, + help=argparse.SUPPRESS) # workaround for -- eaten by positionals + 
parser.add_argument('url', + help="URL to the source dataset. A path to a file or directory") + parser.add_argument('-n', '--name', + help="Name of the new source (default: generate automatically)") + parser.add_argument('-f', '--format', required=True, + help="Source dataset format") + parser.add_argument('-r', '--path', + help="A path relative to URL to the source data. Useful to specify " + "a path to subset, subtask, or a specific file in URL.") + parser.add_argument('--no-check', action='store_true', + help="Don't try to read the source after importing") + parser.add_argument('--no-cache', action='store_true', + help="Don't put a copy into the project cache") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the project to operate on (default: current dir)") + parser.add_argument('extra_args', nargs=argparse.REMAINDER, + help="Additional arguments for extractor (pass '-- -h' for help). " + "Must be specified after the main command arguments and after " + "the '--' separator") + parser.set_defaults(command=add_command) return parser +@scoped def add_command(args): - project = load_project(args.project_dir) + # Workaround. Required positionals consume positionals from the end + args._positionals += join_cli_args(args, 'url', 'extra_args') - if args.source_type == 'git': - name = args.name - if name is None: - name = osp.splitext(osp.basename(args.url))[0] + has_sep = '--' in args._positionals + if has_sep: + pos = args._positionals.index('--') + else: + pos = 1 + args.url = (args._positionals[:pos] or [''])[0] + args.extra_args = args._positionals[pos + has_sep:] - if project.env.git.has_submodule(name): - raise CliException("Git submodule '%s' already exists" % name) + show_plugin_help = '-h' in args.extra_args or '--help' in args.extra_args - try: - project.get_source(name) - raise CliException("Source '%s' already exists" % name) - except KeyError: - pass - - rel_local_dir = project.local_source_dir(name) - local_dir = osp.join(project.config.project_dir, rel_local_dir) - url = args.url - project.env.git.create_submodule(name, local_dir, - url=url, branch=args.branch, no_checkout=not args.checkout) - elif args.source_type == 'path': - url = osp.abspath(args.url) - if not osp.exists(url): - raise CliException("Source path '%s' does not exist" % url) - - name = args.name - if name is None: - name = osp.splitext(osp.basename(url))[0] - - if project.env.git.has_submodule(name): - raise CliException("Git submodule '%s' already exists" % name) - - try: - project.get_source(name) + project = None + try: + project = scope_add(load_project(args.project_dir)) + except ProjectNotFoundError: + if not show_plugin_help and args.project_dir: + raise + + if project is not None: + env = project.env + else: + env = Environment() + + fmt = args.format + if fmt in env.importers: + arg_parser = env.importers[fmt] + elif fmt in env.extractors: + arg_parser = env.extractors[fmt] + else: + raise CliException("Unknown format '%s'. 
A format can be added" + "by providing an Extractor and Importer plugins" % fmt) + + extra_args = arg_parser.parse_cmdline(args.extra_args) + + name = args.name + if name: + if name in project.working_tree.sources: raise CliException("Source '%s' already exists" % name) - except KeyError: - pass - - rel_local_dir = project.local_source_dir(name) - local_dir = osp.join(project.config.project_dir, rel_local_dir) - - if args.copy: - log.info("Copying from '%s' to '%s'" % (url, local_dir)) - if osp.isdir(url): - # copytree requires destination dir not to exist - shutil.copytree(url, local_dir) - url = rel_local_dir - elif osp.isfile(url): - os.makedirs(local_dir) - shutil.copy2(url, local_dir) - url = osp.join(rel_local_dir, osp.basename(url)) - else: - raise Exception("Expected file or directory") - else: - os.makedirs(local_dir) - - project.add_source(name, { 'url': url, 'format': args.format }) - - if not args.skip_check: + else: + name = generate_next_name( + list(project.working_tree.sources) + os.listdir(), + 'source', sep='-', default='1') + + project.import_source(name, url=args.url, format=args.format, + options=extra_args, no_cache=args.no_cache, rpath=args.path) + on_error_do(project.remove_source, name, ignore_errors=True, + kwargs={'force': True, 'keep_data': False}) + + if not args.no_check: log.info("Checking the source...") - try: - project.make_source_project(name).make_dataset() - except Exception: - shutil.rmtree(local_dir, ignore_errors=True) - raise + project.working_tree.make_dataset(name) - project.save() + project.working_tree.save() - log.info("Source '%s' has been added to the project, location: '%s'" \ - % (name, rel_local_dir)) + log.info("Source '%s' with format '%s' has been added to the project", + name, args.format) return 0 def build_remove_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor(help="Remove source from project", - description="Remove a source from a project.") + description="Remove a source from a project") - parser.add_argument('-n', '--name', required=True, - help="Name of the source to be removed") + parser.add_argument('names', nargs='+', + help="Names of the sources to be removed") parser.add_argument('--force', action='store_true', - help="Ignore possible errors during removal") + help="Do not fail and stop on errors during removal") parser.add_argument('--keep-data', action='store_true', - help="Do not remove source data") + help="Do not remove source data from the working directory, remove " + "only project metainfo.") parser.add_argument('-p', '--project', dest='project_dir', default='.', help="Directory of the project to operate on (default: current dir)") parser.set_defaults(command=remove_command) return parser +@scoped def remove_command(args): - project = load_project(args.project_dir) + project = scope_add(load_project(args.project_dir)) - name = args.name - if not name: + if not args.names: raise CliException("Expected source name") - try: - project.get_source(name) - except KeyError: - if not args.force: - raise CliException("Source '%s' does not exist" % name) - if project.env.git.has_submodule(name): - if args.force: - log.warning("Forcefully removing the '%s' source..." 
% name) + for name in args.names: + project.remove_source(name, force=args.force, keep_data=args.keep_data) + project.working_tree.save() - project.env.git.remove_submodule(name, force=args.force) - - source_dir = osp.join(project.config.project_dir, - project.local_source_dir(name)) - project.remove_source(name) - project.save() - - if not args.keep_data: - shutil.rmtree(source_dir, ignore_errors=True) - - log.info("Source '%s' has been removed from the project" % name) + log.info("Sources '%s' have been removed from the project" % \ + ', '.join(args.names)) return 0 def build_info_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor() - parser.add_argument('-n', '--name', + parser.add_argument('name', nargs='?', help="Source name") parser.add_argument('-v', '--verbose', action='store_true', help="Show details") @@ -238,17 +189,18 @@ def build_info_parser(parser_ctor=argparse.ArgumentParser): return parser +@scoped def info_command(args): - project = load_project(args.project_dir) + project = scope_add(load_project(args.project_dir)) if args.name: - source = project.get_source(args.name) + source = project.working_tree.sources[args.name] print(source) else: - for name, conf in project.config.sources.items(): + for name, conf in project.working_tree.sources.items(): print(name) if args.verbose: - print(dict(conf)) + print(conf) def build_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor(description=""" diff --git a/datumaro/cli/util/__init__.py b/datumaro/cli/util/__init__.py index 156bbd6cf3..f6a41b9874 100644 --- a/datumaro/cli/util/__init__.py +++ b/datumaro/cli/util/__init__.py @@ -2,13 +2,10 @@ # # SPDX-License-Identifier: MIT +from typing import Iterable, List import argparse import textwrap -from datumaro.components.errors import DatumaroError - - -class CliException(DatumaroError): pass def add_subparser(subparsers, name, builder): return builder(lambda **kwargs: subparsers.add_parser(name, **kwargs)) @@ -59,3 +56,16 @@ def __call__(self, parser, args, values, option_string=None): def at_least(n): return required_count(n, 0) + +def join_cli_args(args: argparse.Namespace, *names: Iterable[str]) -> List: + "Merges arg values in a list" + + joined = [] + + for name in names: + value = getattr(args, name) + if not isinstance(value, list): + value = [value] + joined += value + + return joined diff --git a/datumaro/cli/util/errors.py b/datumaro/cli/util/errors.py new file mode 100644 index 0000000000..40778f5f79 --- /dev/null +++ b/datumaro/cli/util/errors.py @@ -0,0 +1,18 @@ +# Copyright (C) 2021 Intel Corporation +# +# SPDX-License-Identifier: MIT + +from attr import attrib, attrs + +from datumaro.components.errors import DatumaroError + + +class CliException(DatumaroError): pass + +@attrs +class WrongRevpathError(CliException): + problems = attrib() + + def __str__(self): + return "Failed to parse revspec:\n " + \ + '\n '.join(str(p) for p in self.problems) diff --git a/datumaro/cli/util/project.py b/datumaro/cli/util/project.py index d3b5b56944..11117d5c4f 100644 --- a/datumaro/cli/util/project.py +++ b/datumaro/cli/util/project.py @@ -2,15 +2,21 @@ # # SPDX-License-Identifier: MIT +from typing import Optional, Tuple import os import re -from datumaro.components.project import Project -from datumaro.util import cast +from datumaro.cli.util.errors import WrongRevpathError +from datumaro.components.dataset import Dataset +from datumaro.components.environment import Environment +from datumaro.components.errors import DatumaroError, ProjectNotFoundError +from 
datumaro.components.project import Project, Revision +from datumaro.util.os_util import generate_next_name +from datumaro.util.scope import on_error_do, scoped -def load_project(project_dir): - return Project.load(project_dir) +def load_project(project_dir, readonly=False): + return Project(project_dir, readonly=readonly) def generate_next_file_name(basename, basedir='.', sep='.', ext=''): """ @@ -22,17 +28,127 @@ def generate_next_file_name(basename, basedir='.', sep='.', ext=''): return generate_next_name(os.listdir(basedir), basename, sep, ext) -def generate_next_name(names, basename, sep='.', suffix='', default=None): - pattern = re.compile(r'%s(?:%s(\d+))?%s' % \ - tuple(map(re.escape, [basename, sep, suffix]))) - matches = [match for match in (pattern.match(n) for n in names) if match] - - max_idx = max([cast(match[1], int, 0) for match in matches], default=None) - if max_idx is None: - if default is not None: - idx = sep + str(default) - else: - idx = '' +def parse_dataset_pathspec(s: str, + env: Optional[Environment] = None) -> Dataset: + """ + Parses Dataset paths. The syntax is: + - [ : ] + + Returns: a dataset from the parsed path + """ + + match = re.fullmatch(r""" + (?P(?: [^:] | :[/\\] )+) + (:(?P.+))? + """, s, flags=re.VERBOSE) + if not match: + raise ValueError("Failed to recognize dataset pathspec in '%s'" % s) + match = match.groupdict() + + path = match["dataset_path"] + format = match["format"] + return Dataset.import_from(path, format, env=env) + +@scoped +def parse_revspec(s: str, ctx_project: Optional[Project] = None) \ + -> Tuple[Dataset, Project]: + """ + Parses Revision paths. The syntax is: + - [ @ ] [ : ] + - [ : ] + - + The second and the third forms assume an existing "current" project. + + Returns: the dataset and the project from the parsed path. + The project is only returned when specified in the revpath. + """ + + match = re.fullmatch(r""" + (?P(?: [^@:] | :[/\\] )+) + (@(?P[^:]+))? + (:(?P.+))? + """, s, flags=re.VERBOSE) + if not match: + raise ValueError("Failed to recognize revspec in '%s'" % s) + match = match.groupdict() + + proj_path = match["proj_path"] + rev = match["rev"] + source = match["source"] + + target_project = None + + assert proj_path + if rev: + target_project = load_project(proj_path, readonly=True) + project = target_project + # proj_path is either proj_path or rev or source name + elif Project.find_project_dir(proj_path): + target_project = load_project(proj_path, readonly=True) + project = target_project + elif ctx_project: + project = ctx_project + if project.is_ref(proj_path): + rev = proj_path + elif not source: + source = proj_path + else: - idx = sep + str(max_idx + 1) - return basename + idx + suffix + raise ProjectNotFoundError("Failed to find project at '%s'. " \ + "Specify project path with '-p/--project' or in the " + "target pathspec." % proj_path) + + if target_project: + on_error_do(Project.close, target_project, ignore_errors=True) + + tree = project.get_rev(rev) + return tree.make_dataset(source), target_project + +def parse_full_revpath(s: str, ctx_project: Optional[Project] = None) \ + -> Tuple[Dataset, Optional[Project]]: + """ + revpath - either a Dataset path or a Revision path. + + Returns: the dataset and the project from the parsed path + The project is only returned when specified in the revpath. 
+ """ + + if ctx_project: + env = ctx_project.env + else: + env = Environment() + + errors = [] + try: + return parse_dataset_pathspec(s, env=env), None + except (DatumaroError, OSError) as e: + errors.append(e) + + try: + return parse_revspec(s, ctx_project=ctx_project) + except (DatumaroError, OSError) as e: + errors.append(e) + + raise WrongRevpathError(problems=errors) + +def split_local_revpath(revpath: str) -> Tuple[Revision, str]: + """ + Splits the given string into revpath components. + + A local revpath is a path to a revision within the current project. + The syntax is: + - [ <revision> : ] [ <target> ] + At least one part must be present. + + Returns: (revision, build target) + """ + + sep_pos = revpath.find(':') + if -1 < sep_pos: + rev = revpath[:sep_pos] + target = revpath[sep_pos + 1:] + else: + rev = '' + target = revpath + + return rev, target diff --git a/datumaro/components/cli_plugin.py b/datumaro/components/cli_plugin.py index a884bf663c..3d5b1a3d46 100644 --- a/datumaro/components/cli_plugin.py +++ b/datumaro/components/cli_plugin.py @@ -2,12 +2,32 @@ # # SPDX-License-Identifier: MIT +from typing import List, Type import argparse import logging as log from datumaro.cli.util import MultilineFormatter from datumaro.util import to_snake_case +_plugin_types = None +def plugin_types() -> List[Type['CliPlugin']]: + global _plugin_types + if _plugin_types is None: + from datumaro.components.converter import Converter + from datumaro.components.extractor import Extractor, Importer, Transform + from datumaro.components.launcher import Launcher + from datumaro.components.validator import Validator + + _plugin_types = [Launcher, Extractor, Transform, Importer, + Converter, Validator] + + return _plugin_types + +def remove_plugin_type(s): + for t in {'transform', 'extractor', 'converter', 'launcher', 'importer', + 'validator'}: + s = s.replace('_' + t, '') + return s class CliPlugin: @staticmethod @@ -19,14 +39,7 @@ def _get_name(cls): def _get_doc(cls): doc = getattr(cls, '__doc__', "") if doc: - from datumaro.components.converter import Converter - from datumaro.components.extractor import ( - Extractor, Importer, Transform, - ) - from datumaro.components.launcher import Launcher - base_classes = [Launcher, Extractor, Transform, Importer, Converter] - - if any(getattr(t, '__doc__', '') == doc for t in base_classes): + if any(getattr(t, '__doc__', '') == doc for t in plugin_types()): doc = '' return doc @@ -53,10 +66,3 @@ def parse_cmdline(cls, args=None): '\n\t'.join('%s: %s' % (k, v) for k, v in args.items())) return args - -def remove_plugin_type(s): - for t in { - 'transform', 'extractor', 'converter', 'launcher', 'importer', 'validator', - }: - s = s.replace('_' + t, '') - return s diff --git a/datumaro/components/config.py b/datumaro/components/config.py index a78bd7757e..d4ed90981d 100644 --- a/datumaro/components/config.py +++ b/datumaro/components/config.py @@ -1,4 +1,4 @@ -# Copyright (C) 2019-2020 Intel Corporation +# Copyright (C) 2019-2021 Intel Corporation # # SPDX-License-Identifier: MIT diff --git a/datumaro/components/config_model.py b/datumaro/components/config_model.py index 5d85df0cf3..786bf2635e 100644 --- a/datumaro/components/config_model.py +++ b/datumaro/components/config_model.py @@ -1,21 +1,28 @@ -# Copyright (C) 2019-2020 Intel Corporation +# Copyright (C) 2019-2021 Intel Corporation # # SPDX-License-Identifier: MIT from datumaro.components.config import Config from datumaro.components.config import DictConfig as _DictConfig from datumaro.components.config import
SchemaBuilder as _SchemaBuilder +from datumaro.util import find SOURCE_SCHEMA = _SchemaBuilder() \ .add('url', str) \ + .add('path', str) \ .add('format', str) \ .add('options', dict) \ + .add('hash', str) \ .build() class Source(Config): def __init__(self, config=None): super().__init__(config, schema=SOURCE_SCHEMA) + @property + def is_generated(self) -> bool: + return not self.url + MODEL_SCHEMA = _SchemaBuilder() \ .add('launcher', str) \ @@ -27,35 +34,108 @@ def __init__(self, config=None): super().__init__(config, schema=MODEL_SCHEMA) +BUILDSTAGE_SCHEMA = _SchemaBuilder() \ + .add('name', str) \ + .add('type', str) \ + .add('kind', str) \ + .add('hash', str) \ + .add('params', dict) \ + .build() + +class BuildStage(Config): + def __init__(self, config=None): + super().__init__(config, schema=BUILDSTAGE_SCHEMA) + +BUILDTARGET_SCHEMA = _SchemaBuilder() \ + .add('stages', list) \ + .add('parents', list) \ + .build() + +class BuildTarget(Config): + def __init__(self, config=None): + super().__init__(config, schema=BUILDTARGET_SCHEMA) + self.stages = [BuildStage(o) for o in self.stages] + + @property + def root(self): + return self.stages[0] + + @property + def head(self): + return self.stages[-1] + + @property + def has_stages(self) -> bool: + return 1 < len(self.stages) + + def find_stage(self, stage): + if stage == 'root': + return self.root + elif stage == 'head': + return self.head + return find(self.stages, lambda x: x.name == stage or x == stage) + + def get_stage(self, stage): + res = self.find_stage(stage) + if res is None: + raise KeyError("Unknown stage '%s'" % stage) + return res + + +TREE_SCHEMA = _SchemaBuilder() \ + .add('format_version', int) \ + \ + .add('sources', lambda: _DictConfig(lambda v=None: Source(v))) \ + .add('build_targets', lambda: _DictConfig(lambda v=None: BuildTarget(v))) \ + \ + .add('base_dir', str, internal=True) \ + .add('config_path', str, internal=True) \ + .build() + +TREE_DEFAULT_CONFIG = Config({ + 'format_version': 2, + + 'config_path': '', +}, mutable=False, schema=TREE_SCHEMA) + +class TreeConfig(Config): + def __init__(self, config=None, mutable=True): + super().__init__(config=config, mutable=mutable, + fallback=TREE_DEFAULT_CONFIG, schema=TREE_SCHEMA) + + PROJECT_SCHEMA = _SchemaBuilder() \ - .add('project_name', str) \ .add('format_version', int) \ \ - .add('subsets', list) \ - .add('sources', lambda: _DictConfig( - lambda v=None: Source(v))) \ - .add('models', lambda: _DictConfig( - lambda v=None: Model(v))) \ + .add('models', lambda: _DictConfig(lambda v=None: Model(v))) \ \ - .add('models_dir', str, internal=True) \ - .add('plugins_dir', str, internal=True) \ - .add('sources_dir', str, internal=True) \ - .add('dataset_dir', str, internal=True) \ - .add('project_filename', str, internal=True) \ - .add('project_dir', str, internal=True) \ - .add('env_dir', str, internal=True) \ .build() PROJECT_DEFAULT_CONFIG = Config({ - 'project_name': 'undefined', - 'format_version': 1, + 'format_version': 2, +}, mutable=False, schema=PROJECT_SCHEMA) - 'sources_dir': 'sources', - 'dataset_dir': 'dataset', - 'models_dir': 'models', - 'plugins_dir': 'plugins', +class ProjectConfig(Config): + def __init__(self, config=None, mutable=True): + super().__init__(config=config, mutable=mutable, + fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA) - 'project_filename': 'config.yaml', - 'project_dir': '', - 'env_dir': '.datumaro', -}, mutable=False, schema=PROJECT_SCHEMA) + +class PipelineConfig(Config): + pass + +class ProjectLayout: + aux_dir = '.datumaro' 
+ cache_dir = 'cache' + index_dir = 'index' + working_tree_dir = 'tree' + head_file = 'head' + tmp_dir = 'tmp' + models_dir = 'models' + plugins_dir = 'plugins' + conf_file = 'config.yml' + + +class TreeLayout: + conf_file = 'config.yml' + sources_dir = 'sources' diff --git a/datumaro/components/converter.py b/datumaro/components/converter.py index 21c0410cad..33c2d6a24b 100644 --- a/datumaro/components/converter.py +++ b/datumaro/components/converter.py @@ -13,6 +13,7 @@ from datumaro.components.dataset import DatasetPatch from datumaro.components.extractor import DatasetItem from datumaro.util.image import Image +from datumaro.util.os_util import rmtree from datumaro.util.scope import on_error_do, scoped @@ -55,12 +56,12 @@ def patch(cls, dataset, patch, save_dir, **options): tmpdir = mkdtemp(dir=osp.dirname(save_dir), prefix=osp.basename(save_dir), suffix='.tmp') - on_error_do(shutil.rmtree, tmpdir, ignore_errors=True) + on_error_do(rmtree, tmpdir, ignore_errors=True) shutil.copymode(save_dir, tmpdir) retval = cls.convert(dataset, tmpdir, **options) - shutil.rmtree(save_dir) + rmtree(save_dir) os.replace(tmpdir, save_dir) return retval diff --git a/datumaro/components/dataset.py b/datumaro/components/dataset.py index 06e52dd8eb..fd889b2947 100644 --- a/datumaro/components/dataset.py +++ b/datumaro/components/dataset.py @@ -12,7 +12,6 @@ import logging as log import os import os.path as osp -import shutil from datumaro.components.annotation import AnnotationType, LabelCategories from datumaro.components.dataset_filter import ( @@ -29,6 +28,7 @@ ) from datumaro.util import is_method_redefined from datumaro.util.log_utils import logging_disabled +from datumaro.util.os_util import rmtree from datumaro.util.scope import on_error_do, scoped DEFAULT_FORMAT = 'datumaro' @@ -802,7 +802,7 @@ def export(self, save_dir: str, format, **kwargs): save_dir = osp.abspath(save_dir) if not osp.exists(save_dir): - on_error_do(shutil.rmtree, save_dir, ignore_errors=True) + on_error_do(rmtree, save_dir, ignore_errors=True) inplace = False os.makedirs(save_dir, exist_ok=True) @@ -840,8 +840,7 @@ def import_from(cls, path: str, format: str = None, env: Environment = None, if format in env.importers: importer = env.make_importer(format) with logging_disabled(log.INFO): - project = importer(path, **kwargs) - detected_sources = list(project.config.sources.values()) + detected_sources = importer(path, **kwargs) elif format in env.extractors: detected_sources = [{ 'url': path, 'format': format, 'options': kwargs diff --git a/datumaro/components/environment.py b/datumaro/components/environment.py index 0170bf979d..68ad651d10 100644 --- a/datumaro/components/environment.py +++ b/datumaro/components/environment.py @@ -4,16 +4,13 @@ from functools import partial from glob import glob -from typing import Dict, Iterable +from typing import Iterable import inspect import logging as log import os import os.path as osp -import git - -from datumaro.components.config import Config -from datumaro.components.config_model import Model, Source +from datumaro.components.cli_plugin import CliPlugin, plugin_types from datumaro.util.os_util import import_foreign_module @@ -41,25 +38,12 @@ def __contains__(self, key): def __iter__(self): return iter(self.items) -class ModelRegistry(Registry): - def batch_register(self, items: Dict[str, Model]): - for name, model in items.items(): - self.register(name, model) - - -class SourceRegistry(Registry): - def batch_register(self, items: Dict[str, Source]): - for name, source in items.items(): 
- self.register(name, source) - class PluginRegistry(Registry): def __init__(self, filter=None): #pylint: disable=redefined-builtin super().__init__() self.filter = filter def batch_register(self, values: Iterable): - from datumaro.components.cli_plugin import CliPlugin - for v in values: if self.filter and not self.filter(v): continue @@ -67,62 +51,10 @@ def batch_register(self, values: Iterable): self.register(name, v) -class GitWrapper: - def __init__(self, config=None): - self.repo = None - - if config is not None and config.project_dir: - self.init(config.project_dir) - - @staticmethod - def _git_dir(base_path): - return osp.join(base_path, '.git') - - @classmethod - def spawn(cls, path): - spawn = not osp.isdir(cls._git_dir(path)) - repo = git.Repo.init(path=path) - if spawn: - repo.config_writer().set_value("user", "name", "User") \ - .set_value("user", "email", "user@nowhere.com") \ - .release() - # gitpython does not support init, use git directly - repo.git.init() - repo.git.commit('-m', 'Initial commit', '--allow-empty') - return repo - - def init(self, path): - self.repo = self.spawn(path) - return self.repo - - def is_initialized(self): - return self.repo is not None - - def create_submodule(self, name, dst_dir, **kwargs): - self.repo.create_submodule(name, dst_dir, **kwargs) - - def has_submodule(self, name): - return name in [submodule.name for submodule in self.repo.submodules] - - def remove_submodule(self, name, **kwargs): - return self.repo.submodule(name).remove(**kwargs) - - class Environment: _builtin_plugins = None - def __init__(self, config=None): - from datumaro.components.project import ( - PROJECT_DEFAULT_CONFIG, PROJECT_SCHEMA, - ) - config = Config(config, - fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA) - - self.models = ModelRegistry() - self.sources = SourceRegistry() - - self.git = GitWrapper(config) - + def __init__(self): def _filter(accept, skip=None): accept = (accept, ) if inspect.isclass(accept) else tuple(accept) skip = {skip} if inspect.isclass(skip) else set(skip or []) @@ -134,6 +66,7 @@ def _filter(accept, skip=None): Extractor, Importer, ItemTransform, SourceExtractor, Transform, ) from datumaro.components.launcher import Launcher + from datumaro.components.validator import Validator self._extractors = PluginRegistry(_filter(Extractor, skip=SourceExtractor)) self._importers = PluginRegistry(_filter(Importer)) @@ -141,6 +74,7 @@ def _filter(accept, skip=None): self._converters = PluginRegistry(_filter(Converter)) self._transforms = PluginRegistry(_filter(Transform, skip=ItemTransform)) + self._validators = PluginRegistry(_filter(Validator)) self._builtins_initialized = False def _get_plugin_registry(self, name): @@ -169,6 +103,10 @@ def converters(self) -> PluginRegistry: def transforms(self) -> PluginRegistry: return self._get_plugin_registry('_transforms') + @property + def validators(self) -> PluginRegistry: + return self._get_plugin_registry('_validators') + @staticmethod def _find_plugins(plugins_dir): plugins = [] @@ -208,15 +146,7 @@ def _import_module(cls, module_dir, module_name, types, package=None): @classmethod def _load_plugins(cls, plugins_dir, types=None): - if not types: - from datumaro.components.converter import Converter - from datumaro.components.extractor import ( - Extractor, Importer, Transform, - ) - from datumaro.components.launcher import Launcher - types = [Extractor, Converter, Importer, Launcher, Transform] - - types = tuple(types) + types = tuple(types or plugin_types()) plugins = 
cls._find_plugins(plugins_dir) @@ -226,11 +156,7 @@ def _load_plugins(cls, plugins_dir, types=None): exports = cls._import_module(module_dir, module_name, types, package) except Exception as e: - module_search_error = ImportError - try: - module_search_error = ModuleNotFoundError # python 3.6+ - except NameError: - pass + module_search_error = ModuleNotFoundError message = ["Failed to import module '%s': %s", module_name, e] if isinstance(e, module_search_error): @@ -273,6 +199,7 @@ def _register_plugins(self, plugins): self.launchers.batch_register(plugins) self.converters.batch_register(plugins) self.transforms.batch_register(plugins) + self.validators.batch_register(plugins) def make_extractor(self, name, *args, **kwargs): return self.extractors.get(name)(*args, **kwargs) @@ -292,12 +219,6 @@ def make_converter(self, name, *args, **kwargs): def make_transform(self, name, *args, **kwargs): return partial(self.transforms.get(name), *args, **kwargs) - def register_model(self, name, model): - self.models.register(name, model) - - def unregister_model(self, name): - self.models.unregister(name) - def is_format_known(self, name): return name in self.importers or name in self.extractors diff --git a/datumaro/components/errors.py b/datumaro/components/errors.py index 448ca35f1c..1bab7800a8 100644 --- a/datumaro/components/errors.py +++ b/datumaro/components/errors.py @@ -12,6 +12,117 @@ def __str__(self): class DatumaroError(Exception): pass + +class VcsError(DatumaroError): + pass + +class ReadonlyDatasetError(VcsError): + def __str__(self): + return "Can't update a read-only dataset" + +class ReadonlyProjectError(VcsError): + def __str__(self): + return "Can't change a read-only project" + +@attrs +class UnknownRefError(VcsError): + ref = attrib() + + def __str__(self): + return f"Can't parse ref '{self.ref}'" + +class MissingObjectError(VcsError): + pass + +class MismatchingObjectError(VcsError): + pass + +@attrs +class UnsavedChangesError(VcsError): + paths = attrib() + + def __str__(self): + return "There are some uncommitted changes: %s" % ', '.join(self.paths) + +class ForeignChangesError(VcsError): + pass + +class EmptyCommitError(VcsError): + pass + +class PathOutsideSourceError(VcsError): + pass + +class SourceUrlInsideProjectError(VcsError): + def __str__(self): + return "Source URL cannot point inside the project" + +class UnexpectedUrlError(VcsError): + pass + +class PipelineError(DatumaroError): + pass + +class InvalidPipelineError(PipelineError): + pass + +class EmptyPipelineError(InvalidPipelineError): + pass + +class MultiplePipelineHeadsError(InvalidPipelineError): + pass + +class MissingPipelineHeadError(InvalidPipelineError): + pass + +class InvalidStageError(InvalidPipelineError): + pass + +class UnknownStageError(InvalidStageError): + pass + + +class MigrationError(DatumaroError): + pass + +class OldProjectError(DatumaroError): + def __str__(self): + return """ + The project you're trying to load was + created by the old Datumaro version. Try to migrate the + project with 'datum project migrate' and then reload. 
+ """ + + +@attrs +class ProjectNotFoundError(DatumaroError): + path = attrib() + + def __str__(self): + return f"Can't find project at '{self.path}'" + +@attrs +class ProjectAlreadyExists(DatumaroError): + path = attrib() + + def __str__(self): + return f"Can't create project: a project already exists " \ + f"at '{self.path}'" + +@attrs +class UnknownSourceError(DatumaroError): + name = attrib() + + def __str__(self): + return f"Unknown source '{self.name}'" + +@attrs +class UnknownTargetError(DatumaroError): + name = attrib() + + def __str__(self): + return f"Unknown target '{self.name}'" + @attrs class UnknownFormatError(DatumaroError): format = attrib() @@ -22,17 +133,11 @@ def __str__(self): "to the environment" @attrs -class DatasetError(DatumaroError): - item_id = attrib() - -@attrs -class RepeatedItemError(DatasetError): - def __str__(self): - return f"Item {self.item_id} is repeated in the source sequence." +class SourceExistsError(DatumaroError): + name = attrib() -class CategoriesRedefinedError(DatasetError): def __str__(self): - return "Categories can only be set once for a dataset" + return f"Source '{self.name}' already exists" class DatasetImportError(DatumaroError): @@ -59,6 +164,21 @@ def __str__(self): return "Failed to detect dataset format automatically: " \ "no matching formats found" + +@attrs +class DatasetError(DatumaroError): + item_id = attrib() + +class CategoriesRedefinedError(DatasetError): + def __str__(self): + return "Categories can only be set once for a dataset" + +@attrs +class RepeatedItemError(DatasetError): + def __str__(self): + return f"Item {self.item_id} is repeated in the source sequence." + + @attrs class DatasetQualityError(DatasetError): pass @@ -84,6 +204,7 @@ def __str__(self): "found %s, expected %s, group %s" % \ (self.item_id, self.found, self.expected, self.group) + @attrs class DatasetMergeError(DatasetError): sources = attrib(converter=set, factory=set, kw_only=True) @@ -148,6 +269,7 @@ def to_dict(self): 'severity': self.severity.name, } + @attrs class DatasetItemValidationError(DatasetValidationError): item_id = attrib() diff --git a/datumaro/components/extractor.py b/datumaro/components/extractor.py index 96ccab017d..0e18ff27bc 100644 --- a/datumaro/components/extractor.py +++ b/datumaro/components/extractor.py @@ -12,6 +12,7 @@ import numpy as np from datumaro.components.annotation import AnnotationType, Categories +from datumaro.components.cli_plugin import CliPlugin from datumaro.components.errors import DatasetNotFoundError from datumaro.util import is_method_redefined from datumaro.util.attrs_util import default_if_none, not_empty @@ -172,7 +173,7 @@ def get(self, id, subset=None): return item return None -class Extractor(ExtractorBase): +class Extractor(ExtractorBase, CliPlugin): """ A base class for user-defined and built-in extractors. 
Should be used in cases, where SourceExtractor is not enough, @@ -205,9 +206,12 @@ def get(self, id, subset=None): assert subset == self._subset, '%s != %s' % (subset, self._subset) return super().get(id, subset or self._subset) -class Importer: +class Importer(CliPlugin): @classmethod def detect(cls, path): + if not path or not osp.exists(path): + return False + return len(cls.find_sources_with_params(path)) != 0 @classmethod @@ -219,22 +223,21 @@ def find_sources_with_params(cls, path, **extra_params) -> List[Dict]: return cls.find_sources(path) def __call__(self, path, **extra_params): - from datumaro.components.project import Project # cyclic import - project = Project() + if not path or not osp.exists(path): + raise DatasetNotFoundError(path) - sources = self.find_sources_with_params(osp.normpath(path), **extra_params) - if len(sources) == 0: + found_sources = self.find_sources_with_params(osp.normpath(path), **extra_params) + if not found_sources: raise DatasetNotFoundError(path) - for desc in sources: + sources = [] + for desc in found_sources: params = dict(extra_params) params.update(desc.get('options', {})) desc['options'] = params + sources.append(desc) - source_name = osp.splitext(osp.basename(desc['url']))[0] - project.add_source(source_name, desc) - - return project + return sources @classmethod def _find_sources_recursive(cls, path: str, ext: Optional[str], @@ -282,7 +285,7 @@ def _find_sources_recursive(cls, path: str, ext: Optional[str], break return sources -class Transform(ExtractorBase): +class Transform(ExtractorBase, CliPlugin): """ A base class for dataset transformations that change dataset items or their annotations. diff --git a/datumaro/components/launcher.py b/datumaro/components/launcher.py index a4e98b1c76..2c88230cbb 100644 --- a/datumaro/components/launcher.py +++ b/datumaro/components/launcher.py @@ -5,12 +5,13 @@ import numpy as np from datumaro.components.annotation import AnnotationType, LabelCategories +from datumaro.components.cli_plugin import CliPlugin from datumaro.components.extractor import Transform from datumaro.util import take_by # pylint: disable=no-self-use -class Launcher: +class Launcher(CliPlugin): def __init__(self, model_dir=None): pass diff --git a/datumaro/components/project.py b/datumaro/components/project.py index 83ce607622..d67e7ca2db 100644 --- a/datumaro/components/project.py +++ b/datumaro/components/project.py @@ -2,488 +2,2387 @@ # # SPDX-License-Identifier: MIT -from collections import OrderedDict +from contextlib import ExitStack, suppress +from enum import Enum, auto +from typing import Any, Dict, Iterable, List, NewType, Optional, Tuple, Union +import json import logging as log import os import os.path as osp +import re import shutil +import tempfile +import unittest.mock + +import networkx as nx +import ruamel.yaml as yaml from datumaro.components.config import Config from datumaro.components.config_model import ( - PROJECT_DEFAULT_CONFIG, PROJECT_SCHEMA, Model, Source, + BuildStage, Model, PipelineConfig, ProjectConfig, ProjectLayout, Source, + TreeConfig, TreeLayout, ) from datumaro.components.dataset import DEFAULT_FORMAT, Dataset, IDataset -from datumaro.components.dataset_filter import ( - XPathAnnotationsFilter, XPathDatasetFilter, -) from datumaro.components.environment import Environment from datumaro.components.errors import ( - MultipleFormatsMatchError, NoMatchingFormatsError, UnknownFormatError, + DatasetMergeError, EmptyCommitError, EmptyPipelineError, + ForeignChangesError, InvalidStageError, 
MigrationError, + MismatchingObjectError, MissingObjectError, MissingPipelineHeadError, + MultiplePipelineHeadsError, OldProjectError, PathOutsideSourceError, + ProjectAlreadyExists, ProjectNotFoundError, ReadonlyDatasetError, + ReadonlyProjectError, SourceExistsError, SourceUrlInsideProjectError, + UnexpectedUrlError, UnknownRefError, UnknownSourceError, UnknownStageError, + UnknownTargetError, UnsavedChangesError, VcsError, +) +from datumaro.components.launcher import Launcher +from datumaro.util import find, parse_str_enum_value +from datumaro.util.log_utils import catch_logs, logging_disabled +from datumaro.util.os_util import ( + copytree, generate_next_name, is_subpath, make_file_name, rmfile, rmtree, ) -from datumaro.components.extractor import DEFAULT_SUBSET_NAME, Extractor -from datumaro.components.launcher import ModelTransform -from datumaro.components.operations import ExactMerge +from datumaro.util.scope import on_error_do, scope_add, scoped -class ProjectDataset(IDataset): - class Subset(Extractor): - def __init__(self, parent, name): - super().__init__(subsets=[name]) - self.parent = parent - self.name = name or DEFAULT_SUBSET_NAME - self.items = OrderedDict() +class ProjectSourceDataset(IDataset): + def __init__(self, path: str, tree: 'Tree', source: str, + readonly: bool = False): + config = tree.sources[source] - def __iter__(self): - yield from self.items.values() + if config.path: + path = osp.join(path, config.path) - def __len__(self): - return len(self.items) + self.__dict__['_dataset'] = Dataset.import_from(path, + env=tree.env, format=config.format, **config.options) - def categories(self): - return self.parent.categories() + self.__dict__['_config'] = config + self.__dict__['_readonly'] = readonly + self.__dict__['name'] = source - def get(self, id, subset=None): - subset = subset or self.name - assert subset == self.name, '%s != %s' % (subset, self.name) - return super().get(id, subset) + def save(self, save_dir=None, **kwargs): + if save_dir is None and self.readonly: + raise ReadonlyDatasetError() + self._dataset.save(save_dir, **kwargs) - def __init__(self, project): - super().__init__() + @property + def readonly(self): + return self._readonly or not self.is_bound + + @property + def config(self): + return self._config + + def __getattr__(self, name): + return getattr(self._dataset, name) + + def __setattr__(self, name, value): + return setattr(self._dataset, name, value) + def __iter__(self): + yield from self._dataset + + def __len__(self): + return len(self._dataset) + + def subsets(self): + return self._dataset.subsets() + + def get_subset(self, name): + return self._dataset.get_subset(name) + + def categories(self): + return self._dataset.categories() + + def get(self, id, subset=None): + return self._dataset.get(id, subset) + + +class IgnoreMode(Enum): + rewrite = auto() + append = auto() + remove = auto() + +def _update_ignore_file(paths: Union[str, List[str]], repo_root: str, + filepath: str, mode: Union[None, str, IgnoreMode] = None): + def _make_ignored_path(path): + path = osp.join(repo_root, osp.normpath(path)) + assert is_subpath(path, base=repo_root) + + # Prepend the '/' to match only direct childs. + # Otherwise the rule can be in any path part. 
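# Editor's note (illustrative, not part of the patch): with repo_root '/repo'
# and path '/repo/data', the returned rule is '/data', which ignores only the
# top-level 'data' entry; a bare 'data' rule would also match nested
# '.../data' paths anywhere in the tree.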
+ return '/' + osp.relpath(path, repo_root).replace('\\', '/') + + header = '# The file is autogenerated by Datumaro' + + mode = parse_str_enum_value(mode, IgnoreMode, IgnoreMode.append) + + if isinstance(paths, str): + paths = [paths] + paths = {osp.join(repo_root, osp.normpath(p)): _make_ignored_path(p) + for p in paths} + + openmode = 'r+' + if not osp.isfile(filepath): + openmode = 'w+' # r+ cannot create, w truncates + with open(filepath, openmode) as f: + lines = [] + if mode in {IgnoreMode.append, IgnoreMode.remove}: + for line in f: + lines.append(line.strip()) + f.seek(0) + + new_lines = [] + for line in lines: + if not line or line.startswith('#'): + new_lines.append(line) + continue + + line_path = osp.join(repo_root, + osp.normpath(line.split('#', maxsplit=1)[0]) \ + .replace('\\', '/').lstrip('/')) + + if mode == IgnoreMode.append: + if line_path in paths: + paths.pop(line_path) + new_lines.append(line) + elif mode == IgnoreMode.remove: + if line_path not in paths: + new_lines.append(line) + + if mode in { IgnoreMode.rewrite, IgnoreMode.append }: + new_lines.extend(paths.values()) + + if not new_lines or new_lines[0] != header: + print(header, file=f) + for line in new_lines: + print(line, file=f) + f.truncate() + +class CrudProxy: + @property + def _data(self): + raise NotImplementedError() + + def __len__(self): + return len(self._data) + + def __getitem__(self, name): + return self._data[name] + + def get(self, name, default=None): + return self._data.get(name, default) + + def __iter__(self): + return iter(self._data.keys()) + + def items(self): + return iter(self._data.items()) + + def __contains__(self, name): + return name in self._data + +class _DataSourceBase(CrudProxy): + def __init__(self, project, config_field): self._project = project - self._env = project.env - config = self.config - env = self.env + self._field = config_field - sources = {} - for s_name, source in config.sources.items(): - s_format = source.format + @property + def _data(self): + return self._project.config[self._field] - url = source.url - if not source.url: - url = osp.join(config.project_dir, config.sources_dir, s_name) + def add(self, name, value): + if name in self: + raise SourceExistsError(name) - if s_format: - sources[s_name] = Dataset.import_from(url, - format=s_format, env=env, **source.options) + return self._data.set(name, value) + + def remove(self, name): + self._data.remove(name) + +class ProjectSources(_DataSourceBase): + def __init__(self, project): + super().__init__(project, 'sources') + + def __getitem__(self, name): + try: + return super().__getitem__(name) + except KeyError as e: + raise KeyError("Unknown source '%s'" % name) from e + + +class BuildStageType(Enum): + source = auto() + project = auto() + transform = auto() + filter = auto() + convert = auto() + inference = auto() + +class Pipeline: + @staticmethod + def _create_graph(config: PipelineConfig): + graph = nx.DiGraph() + for entry in config: + target_name = entry['name'] + parents = entry['parents'] + target = BuildStage(entry['config']) + + graph.add_node(target_name, config=target) + for prev_stage in parents: + graph.add_edge(prev_stage, target_name) + + return graph + + def __init__(self, config: PipelineConfig = None): + self._head = None + + if config is not None: + self._graph = self._create_graph(config) + if not self.head: + raise MissingPipelineHeadError() + else: + self._graph = nx.DiGraph() + + def __getattr__(self, key): + return getattr(self._graph, key) + + @staticmethod + def
_find_head_node(graph) -> Optional[str]: + head = None + for node in graph.nodes: + if graph.out_degree(node) == 0: + if head is not None: + raise MultiplePipelineHeadsError( + "A pipeline can have only one " \ + "main target, but it has at least 2: %s, %s" % \ + (head, node)) + head = node + return head + + @property + def head(self) -> str: + if self._head is None: + self._head = self._find_head_node(self._graph) + return self._head + + @property + def head_node(self): + return self._graph.nodes[self.head] + + @staticmethod + def _serialize(graph) -> PipelineConfig: + serialized = PipelineConfig() + for node_name, node in graph.nodes.items(): + serialized.nodes.append({ + 'name': node_name, + 'parents': list(graph.predecessors(node_name)), + 'config': dict(node['config']), + }) + return serialized + + @staticmethod + def _get_subgraph(graph, target): + """ + Returns a subgraph with all the target dependencies and + the target itself. + """ + return graph.subgraph(nx.ancestors(graph, target) | {target}) + + def get_slice(self, target) -> 'Pipeline': + pipeline = Pipeline() + pipeline._graph = self._get_subgraph(self._graph, target).copy() + return pipeline + +class ProjectBuilder: + def __init__(self, project: 'Project', tree: 'Tree'): + self._project = project + self._tree = tree + + def make_dataset(self, pipeline: Pipeline) -> IDataset: + dataset = self._get_resulting_dataset(pipeline) + + # TODO: May be need to save and load, because it can modify dataset, + # unless we work with the internal format. For example, it can + # add format-specific attributes. It should be needed as soon + # format converting stages (export, convert, load) are allowed. + # + # TODO: If the target was rebuilt from sources, it may require saving + # and hashing, so the resulting hash could be compared with the saved + # one in the pipeline. This is needed to make sure the reproduced + # version of the dataset is correct. Currently we only rely on the + # initial source version check, which can be not enough if stages + # produce different result (because of the library changes etc). + # + # save_in_cache(project, pipeline) # update and check hash in config! + # dataset = load_dataset(project, pipeline) + + return dataset + + def _run_pipeline(self, pipeline: Pipeline): + self._validate_pipeline(pipeline) + + missing_sources, wd_hashes = self._find_missing_sources(pipeline) + for source_name in missing_sources: + source = self._tree.sources[source_name] + + if wd_hashes.get(source_name): + raise ForeignChangesError("Local source '%s' data does not " + "match any previous source revision. Probably, the source " + "was modified outside Datumaro. You can restore the " + "latest source revision with 'checkout' command." % \ + source_name) + + if self._project.readonly: + # Source re-downloading is prohibited in readonly projects + # because it can seriously hurt free storage space. It must + # be run manually, so that the user could know about this. + log.info("Skipping re-downloading missing source '%s', " + "because the project is read-only. 
Automatic downloading " + "is disabled in read-only projects.", source_name) + continue + + # TODO: check if we can avoid computing source hash in some cases + assert source.hash, source_name + with self._project._make_tmp_dir() as tmp_dir: + obj_hash, _, _ = \ + self._project._download_source(source.url, tmp_dir) + + if source.hash != obj_hash: + raise MismatchingObjectError( + "Downloaded source '%s' data is different " \ + "from what is saved in the build pipeline: " + "'%s' vs '%s'" % (source_name, obj_hash, source.hash)) + + return self._init_pipeline(pipeline, working_dir_hashes=wd_hashes) + + def _get_resulting_dataset(self, pipeline): + graph, head = self._run_pipeline(pipeline) + return graph.nodes[head]['dataset'] + + def _init_pipeline(self, pipeline: Pipeline, working_dir_hashes=None): + """ + Initializes datasets in the pipeline nodes. Currently, only the head + node will have a dataset on exit, so no extra memory is wasted + for the intermediate nodes. + """ + + def _join_parent_datasets(force=False): + parents = { p: graph.nodes[p] + for p in graph.predecessors(stage_name) } + + if 1 < len(parents) or force: + try: + dataset = Dataset.from_extractors( + *(p['dataset'] for p in parents.values()), + env=self._tree.env) + except DatasetMergeError as e: + e.sources = set(parents) + raise e else: - sources[s_name] = Project.load(url).make_dataset() - self._sources = sources - - own_source = None - own_source_dir = osp.join(config.project_dir, config.dataset_dir) - if config.project_dir and osp.isdir(own_source_dir): - own_source = Dataset.load(own_source_dir) - - # merge categories - # TODO: implement properly with merging and annotations remapping - categories = ExactMerge.merge_categories(s.categories() - for s in self._sources.values()) - # ovewrite with own categories - if own_source is not None and (not categories or len(own_source) != 0): - categories.update(own_source.categories()) - self._categories = categories - - # merge items - subsets = {} - for source_name, source in self._sources.items(): - log.debug("Loading '%s' source contents..." % source_name) - for item in source: - existing_item = subsets.setdefault( - item.subset, self.Subset(self, item.subset)). \ - items.get(item.id) - if existing_item is not None: - path = existing_item.path - if item.path != path: - path = None # NOTE: move to our own dataset - item = ExactMerge.merge_items(existing_item, item, path=path) - else: - s_config = config.sources[source_name] - if s_config and s_config.format: - # NOTE: consider imported sources as our own dataset - path = None + dataset = next(iter(parents.values()))['dataset'] + + # clear fully utilized datasets to release memory + for p_name, p in parents.items(): + p['_use_count'] = p.get('_use_count', 0) + 1 + + if p_name != head and p['_use_count'] == graph.out_degree(p_name): + p.pop('dataset') + + return dataset + + if working_dir_hashes is None: + working_dir_hashes = {} + def _try_load_from_cache(stage_name: str, stage_config: BuildStage) \ + -> Dataset: + # Check if we can restore this stage from the cache or + # from the working directory. + # + # If we have a hash, we have executed this stage already + # and can have a cache entry or, + # if this is the last stage of a target in the working tree, + # we can use data from the working directory. 
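# Editor's note (illustrative, not part of the patch): for example, right
# after a commit the recorded head-stage hash of a source normally equals the
# hash of its working directory, so the working copy is reused directly; when
# they diverge, the stage falls back to the object cache keyed by the stage
# hash.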
+ stage_hash = stage_config.hash + + data_dir = None + cached = False + + target = ProjectBuildTargets.strip_target_name(stage_name) + if self._tree.is_working_tree and target in self._tree.sources: + data_dir = self._project.source_data_dir(target) + + wd_hash = working_dir_hashes.get(target) + if not wd_hash: + if osp.isdir(data_dir): + wd_hash = self._project.compute_source_hash(data_dir) + working_dir_hashes[target] = wd_hash else: - path = [source_name] + (item.path or []) - item = item.wrap(path=path) + log.debug("Build: skipping checking working dir '%s', " + "because it does not exist", data_dir) + data_dir = None + + if stage_hash != wd_hash: + log.debug("Build: skipping loading stage '%s' from " + "working dir '%s', because hashes does not match", + stage_name, data_dir) + data_dir = None + + if not data_dir: + if self._project._is_cached(stage_hash): + data_dir = self._project.cache_path(stage_hash) + cached = True + elif self._project._can_retrieve_from_vcs_cache(stage_hash): + data_dir = self._project._materialize_obj(stage_hash) + cached = True + + if not data_dir or not osp.isdir(data_dir): + log.debug("Build: skipping loading stage '%s' from " + "cache obj '%s', because it is not available", + stage_name, stage_hash) + return None + + if data_dir: + assert osp.isdir(data_dir), data_dir + log.debug("Build: loading stage '%s' from '%s'", + stage_name, data_dir) + return ProjectSourceDataset(data_dir, self._tree, target, + readonly=cached or self._project.readonly) + + return None + + # Pipeline is assumed to be validated already + graph = pipeline._graph + head = pipeline.head + + # traverse the graph and initialize nodes from sources to the head + to_visit = [head] + while to_visit: + stage_name = to_visit.pop() + stage = graph.nodes[stage_name] + stage_config = stage['config'] + stage_type = BuildStageType[stage_config.type] + + assert stage.get('dataset') is None + + stage_hash = stage_config.hash + if stage_hash: + dataset = _try_load_from_cache(stage_name, stage_config) + if dataset is not None: + stage['dataset'] = dataset + continue + + uninitialized_parents = [] + for p_name in graph.predecessors(stage_name): + parent = graph.nodes[p_name] + if parent.get('dataset') is None: + uninitialized_parents.append(p_name) + + if uninitialized_parents: + to_visit.append(stage_name) + to_visit.extend(uninitialized_parents) + continue + + if stage_type == BuildStageType.transform: + kind = stage_config.kind + try: + transform = self._tree.env.transforms[kind] + except KeyError as e: + raise UnknownStageError("Unknown transform '%s'" % kind) \ + from e + + dataset = _join_parent_datasets() + dataset = dataset.transform(transform, **stage_config.params) + + elif stage_type == BuildStageType.filter: + dataset = _join_parent_datasets() + dataset = dataset.filter(**stage_config.params) + + elif stage_type == BuildStageType.inference: + kind = stage_config.kind + model = self._project.make_model(kind) + + dataset = _join_parent_datasets() + dataset = dataset.run_model(model) + + elif stage_type == BuildStageType.source: + # Stages of type "Source" cannot have inputs, + # they are build tree inputs themselves + assert graph.in_degree(stage_name) == 0, stage_name + + # The only valid situation we get here is that it is a + # generated source: + # - No cache entry + # - No local dir data + source_name = ProjectBuildTargets.strip_target_name(stage_name) + source = self._tree.sources[source_name] + if not source.is_generated: + # Source is missing in the cache and the working tree, + # and 
cannot be retrieved from the VCS cache. + # It is assumed that all the missing sources were + # downloaded earlier. + raise MissingObjectError( + "Failed to initialize stage '%s': " + "object '%s' was not found in cache" % \ + (stage_name, stage_hash)) + + # Generated sources do not require a data directory, + # but they still can be bound to a directory + if self._tree.is_working_tree: + source_dir = self._project.source_data_dir(source_name) + else: + source_dir = None + dataset = ProjectSourceDataset(source_dir, self._tree, + source_name, + readonly=not source_dir or self._project.readonly) + + elif stage_type == BuildStageType.project: + dataset = _join_parent_datasets(force=True) - subsets[item.subset].items[item.id] = item + elif stage_type == BuildStageType.convert: + dataset = _join_parent_datasets() + + else: + raise UnknownStageError("Unexpected stage type '%s'" % \ + stage_type) - # override with our items, fallback to existing images - if own_source is not None: - log.debug("Loading own dataset...") - for item in own_source: - existing_item = subsets.setdefault( - item.subset, self.Subset(self, item.subset)). \ - items.get(item.id) - if existing_item is not None: - item = item.wrap(path=None, - image=ExactMerge.merge_images(existing_item, item)) + stage['dataset'] = dataset - subsets[item.subset].items[item.id] = item + return graph, head - self._subsets = subsets + @staticmethod + def _validate_pipeline(pipeline: Pipeline): + graph = pipeline._graph + if len(graph) == 0: + raise EmptyPipelineError() + + head = pipeline.head + if not head: + raise MissingPipelineHeadError() + + for stage_name, stage in graph.nodes.items(): + stage_type = BuildStageType[stage['config'].type] + + if graph.in_degree(stage_name) == 0: + if stage_type != BuildStageType.source: + raise InvalidStageError( + "Stage '%s' of type '%s' must have inputs" % + (stage_name, stage_type.name)) + else: + if stage_type == BuildStageType.source: + raise InvalidStageError( + "Stage '%s' of type '%s' can't have inputs" % + (stage_name, stage_type.name)) + + if graph.out_degree(stage_name) == 0: + if stage_name != head: + raise InvalidStageError( + "Stage '%s' of type '%s' has no outputs, " + "but is not the head stage" % + (stage_name, stage_type.name)) + + def _find_missing_sources(self, pipeline: Pipeline): + work_dir_hashes = {} + + def _can_retrieve(stage_name: str, stage_config: BuildStage): + obj_hash = stage_config.hash + + source_name = ProjectBuildTargets.strip_target_name(stage_name) + if self._tree.is_working_tree and source_name in self._tree.sources: + data_dir = self._project.source_data_dir(source_name) + + wd_hash = work_dir_hashes.get(source_name) + if not wd_hash and osp.isdir(data_dir): + wd_hash = self._project.compute_source_hash( + self._project.source_data_dir(source_name)) + work_dir_hashes[source_name] = wd_hash + + if obj_hash and obj_hash == wd_hash: + return True + + if obj_hash and self._project.is_obj_cached(obj_hash): + return True + + return False + + missing_sources = set() + checked_deps = set() + unchecked_deps = [pipeline.head] + while unchecked_deps: + stage_name = unchecked_deps.pop() + if stage_name in checked_deps: + continue + + stage_config = pipeline._graph.nodes[stage_name]['config'] + + if not _can_retrieve(stage_name, stage_config): + if pipeline._graph.in_degree(stage_name) == 0: + assert stage_config.type == 'source', stage_config.type + source_name = \ + self._tree.build_targets.strip_target_name(stage_name) + source = self._tree.sources[source_name] + if not 
source.is_generated: + missing_sources.add(source_name) + else: + for p in pipeline._graph.predecessors(stage_name): + if p not in checked_deps: + unchecked_deps.append(p) + continue - self._length = None + checked_deps.add(stage_name) + return missing_sources, work_dir_hashes - def iterate_own(self): - return self.select(lambda item: not item.path) +class ProjectBuildTargets(CrudProxy): + MAIN_TARGET = 'project' + BASE_STAGE = 'root' - def __iter__(self): - for subset in self._subsets.values(): - yield from subset + def __init__(self, tree: 'Tree'): + self._tree = tree - def get_subset(self, name): - return self._subsets[name] + @property + def _data(self): + data = self._tree.config.build_targets + + if self.MAIN_TARGET not in data: + data[self.MAIN_TARGET] = { + 'stages': [ + BuildStage({ + 'name': self.BASE_STAGE, + 'type': BuildStageType.project.name, + }), + ] + } + + for source in self._tree.sources: + if source not in data: + data[source] = { + 'stages': [ + BuildStage({ + 'name': self.BASE_STAGE, + 'type': BuildStageType.source.name, + }), + ] + } + + return data + + def __contains__(self, key): + if '.' in key: + target, stage = self.split_target_name(key) + return target in self._data and \ + self._data[target].find_stage(stage) is not None + return key in self._data + + def add_target(self, name): + return self._data.set(name, { + 'stages': [ + BuildStage({ + 'name': self.BASE_STAGE, + 'type': BuildStageType.source.name, + }), + ] + }) + + def add_stage(self, target, value, prev=None, name=None) -> str: + target_name = target + target_stage_name = None + if '.' in target: + target_name, target_stage_name = self.split_target_name(target) + + if prev is None: + prev = target_stage_name + + target = self._data[target_name] + + if prev: + prev_stage = find(enumerate(target.stages), + lambda e: e[1].name == prev) + if prev_stage is None: + raise KeyError("Can't find stage '%s'" % prev) + prev_stage = prev_stage[0] + else: + prev_stage = len(target.stages) - 1 - def subsets(self): - return self._subsets + name = value.get('name') or name + if not name: + name = generate_next_name((s.name for s in target.stages), + 'stage', sep='-', default='1') + else: + if target.find_stage(name): + raise VcsError("Stage '%s' already exists" % name) + value['name'] = name + + value = BuildStage(value) + assert value.type in BuildStageType.__members__ + target.stages.insert(prev_stage + 1, value) + + return self.make_target_name(target_name, name) + + def remove_target(self, name: str): + assert name != self.MAIN_TARGET, "Can't remove the main target" + self._data.remove(name) + + def remove_stage(self, target: str, name: str): + assert name not in {self.BASE_STAGE}, "Can't remove a default stage" + + target = self._data[target] + idx = find(enumerate(target.stages), lambda e: e[1].name == name) + if idx is None: + raise KeyError("Can't find stage '%s'" % name) + target.stages.remove(idx) + + def add_transform_stage(self, target: str, transform: str, + params: Optional[Dict] = None, name: Optional[str] = None): + if not transform in self._tree.env.transforms: + raise KeyError("Unknown transform '%s'" % transform) + + return self.add_stage(target, { + 'type': BuildStageType.transform.name, + 'kind': transform, + 'params': params or {}, + }, name=name) + + def add_inference_stage(self, target: str, model: str, + params: Optional[Dict] = None, name: Optional[str] = None): + if not model in self._tree._project.models: + raise KeyError("Unknown model '%s'" % model) + + return self.add_stage(target, { + 
'type': BuildStageType.inference.name, + 'kind': model, + 'params': params or {}, + }, name=name) + + def add_filter_stage(self, target: str, expr: str, + params: Optional[Dict] = None, name: Optional[str] = None): + params = params or {} + params['expr'] = expr + return self.add_stage(target, { + 'type': BuildStageType.filter.name, + 'params': params, + }, name=name) + + def add_convert_stage(self, target: str, format: str, + params: Optional[Dict] = None, name: Optional[str] = None): + if not self._tree.env.is_format_known(format): + raise KeyError("Unknown format '%s'" % format) + + return self.add_stage(target, { + 'type': BuildStageType.convert.name, + 'kind': format, + 'params': params or {}, + }, name=name) - def categories(self): - return self._categories + @staticmethod + def make_target_name(target: str, stage: Optional[str] = None) -> str: + if stage: + return '%s.%s' % (target, stage) + return target - def __len__(self): - return sum(len(s) for s in self._subsets.values()) - - def get(self, id, subset=None, path=None): # pylint: disable=arguments-differ - if path: - source = path[0] - return self._sources[source].get(id=id, subset=subset) - return self._subsets.get(subset, {}).get(id) - - def put(self, item, id=None, subset=None, path=None): - if path is None: - path = item.path - - if path: - source = path[0] - # TODO: reverse remapping - self._sources[source].put(item, id=id, subset=subset, path=path[1:]) - - if id is None: - id = item.id - if subset is None: - subset = item.subset - - item = item.wrap(path=path) - if subset not in self._subsets: - self._subsets[subset] = self.Subset(self, subset) - self._subsets[subset].items[id] = item - self._length = None - - return item - - def save(self, save_dir=None, merge=False, recursive=True, - save_images=False): - if save_dir is None: - assert self.config.project_dir - save_dir = self.config.project_dir - project = self._project + @classmethod + def split_target_name(cls, name: str) -> Tuple[str, str]: + if '.' 
in name: + target, stage = name.split('.', maxsplit=1) + if not target: + raise ValueError("Wrong build target name '%s': " + "a name can't be empty" % name) + if not stage: + raise ValueError("Wrong build target name '%s': " + "expected stage name after the separator" % name) else: - merge = True + target = name + stage = cls.BASE_STAGE + return target, stage + + @classmethod + def strip_target_name(cls, name: str) -> str: + return cls.split_target_name(name)[0] + + def _make_full_pipeline(self) -> Pipeline: + pipeline = Pipeline() + graph = pipeline._graph + + for target_name, target in self.items(): + if target_name == self.MAIN_TARGET: + # main target combines all the others + prev_stages = [self.make_target_name(n, t.head.name) + for n, t in self.items() if n != self.MAIN_TARGET] + else: + prev_stages = [self.make_target_name(t, self[t].head.name) + for t in target.parents] + + for stage in target.stages: + stage_name = self.make_target_name(target_name, stage['name']) + + graph.add_node(stage_name, config=stage) + + for prev_stage in prev_stages: + graph.add_edge(prev_stage, stage_name) + prev_stages = [stage_name] - if merge: - project = Project(Config(self.config)) - project.config.remove('sources') + return pipeline - save_dir = osp.abspath(save_dir) - dataset_save_dir = osp.join(save_dir, project.config.dataset_dir) + def make_pipeline(self, target) -> Pipeline: + if not target in self: + raise UnknownTargetError(target) - converter_kwargs = { - 'save_images': save_images, - } + # a subgraph with all the target dependencies + if '.' not in target: + target = self.make_target_name(target, self[target].head.name) - save_dir_existed = osp.exists(save_dir) + return self._make_full_pipeline().get_slice(target) + +class GitWrapper: + @staticmethod + def module(): try: - os.makedirs(save_dir, exist_ok=True) - os.makedirs(dataset_save_dir, exist_ok=True) + import git + return git + except ModuleNotFoundError as e: + raise ModuleNotFoundError("Can't import the 'git' package. " + "Make sure GitPython is installed, or install it with " + "'pip install datumaro[default]'." + ) from e + + def _git_dir(self): + return osp.join(self._project_dir, '.git') + + def __init__(self, project_dir, repo=None): + self._project_dir = project_dir + self.repo = repo + + if repo is None and \ + osp.isdir(project_dir) and osp.isdir(self._git_dir()): + self.repo = self.module().Repo(project_dir) - if merge: - # merge and save the resulting dataset - self.env.converters.get(DEFAULT_FORMAT).convert( - self, dataset_save_dir, **converter_kwargs) + @property + def initialized(self): + return self.repo is not None + + def init(self): + if self.initialized: + return + + repo = self.module().Repo.init(path=self._project_dir) + repo.config_writer() \ + .set_value("user", "name", "User") \ + .set_value("user", "email", "<>") \ + .release() + + # GitPython's init produces an incomplete repo, which becomes normal + # only after a first commit. Unless the commit is done, some + # GitPython's functions will throw useless errors. + # Call "git init" directly to have the desired behaviour. + repo.git.init() + + self.repo = repo + + def close(self): + if self.repo: + self.repo.close() + self.repo = None + + def __del__(self): + with suppress(Exception): + self.close() + + def checkout(self, ref: str = None, dst_dir=None, clean=False, force=False): + # If user wants to navigate to a head, we need to supply its object + # insted of just a string. Otherwise, we'll get a detached head. 
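# Editor's note (illustrative, not part of the patch): checkout('main') would
# resolve repo.heads['main'] and keep HEAD attached to that branch, whereas a
# raw commit hash is not a branch name, so the plain string is used and, as
# noted above, HEAD ends up detached at that commit.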
+ try: + ref_obj = self.repo.heads[ref] + except IndexError: + ref_obj = ref + + commit = self.repo.commit(ref) + tree = commit.tree + + if not dst_dir: + dst_dir = self._project_dir + + repo_dir = osp.abspath(self._project_dir) + dst_dir = osp.abspath(dst_dir) + assert is_subpath(dst_dir, base=repo_dir) + + if not force: + statuses = self.status(tree, base_dir=dst_dir) + + # Only modified files produce conflicts in checkout + dst_rpath = osp.relpath(dst_dir, repo_dir) + conflicts = [osp.join(dst_rpath, p) + for p, s in statuses.items() if s == 'M'] + if conflicts: + raise UnsavedChangesError(conflicts) + + self.repo.head.ref = ref_obj + self.repo.head.reset(working_tree=False) + + if clean: + rmtree(dst_dir) + + self.write_tree(tree, dst_dir) + + def add(self, paths, base=None): + """ + Adds paths to index. + Paths can be truncated relatively to base. + """ + + path_rewriter = None + if base: + base = osp.abspath(base) + repo_root = osp.abspath(self._project_dir) + assert is_subpath(base, base=repo_root), \ + "Base path should be inside of the repo" + base = osp.relpath(base, repo_root) + path_rewriter = lambda entry: osp.relpath(entry.path, base) \ + .replace('\\', '/') + + if isinstance(paths, str): + paths = [paths] + + # A workaround for path_rewriter incompatibility + # with directory paths expansion + paths_to_add = [] + for path in paths: + if not osp.isdir(path): + paths_to_add.append(path) + continue + + for d, _, filenames in os.walk(path): + for fn in filenames: + paths_to_add.append(osp.join(d, fn)) + + self.repo.index.add(paths_to_add, path_rewriter=path_rewriter) + + def commit(self, message) -> str: + """ + Creates a new revision from index. + Returns: new revision hash. + """ + return self.repo.index.commit(message).hexsha + + GitTree = NewType('GitTree', object) + GitStatus = NewType('GitStatus', str) + + def status(self, paths: Union[str, GitTree, Iterable[str]] = None, + base_dir: str = None) -> Dict[str, GitStatus]: + """ + Compares working directory and index. + + Parameters: + - paths - an iterable of paths to compare, a git.Tree, or None. + When None, uses all the paths from HEAD. + - base_dir - a base path for paths. Paths will be prepended by this. + When None or '', uses repo root. Can be useful, if index contains + displaced paths, which needs to be mapped on real paths. 
+ + The statuses are: + - "A" for added paths + - "D" for deleted paths + - "R" for renamed paths + - "M" for paths with modified data + - "T" for changed in the type paths + + Returns: { abspath(base_dir + path): status } + """ + + if paths is None or isinstance(paths, self.module().objects.tree.Tree): + if paths is None: + tree = self.repo.head.commit.tree + else: + tree = paths + paths = (obj.path for obj in tree.traverse() if obj.type == 'blob') + elif isinstance(paths, str): + paths = [paths] + + if not base_dir: + base_dir = self._project_dir + + repo_dir = osp.abspath(self._project_dir) + base_dir = osp.abspath(base_dir) + assert is_subpath(base_dir, base=repo_dir) + + statuses = {} + for obj_path in paths: + file_path = osp.join(base_dir, obj_path) + + index_entry = self.repo.index.entries.get((obj_path, 0), None) + file_exists = osp.isfile(file_path) + if not file_exists and index_entry: + status = 'D' + elif file_exists and not index_entry: + status = 'A' + elif file_exists and index_entry: + # '--ignore-cr-at-eol' doesn't affect '--name-status' + # so we can't really obtain 'T' + status = self.repo.git.diff('--ignore-cr-at-eol', + index_entry.hexsha, file_path) + if status: + status = 'M' + assert status in {'', 'M', 'T'}, status else: - if recursive: - # children items should already be updated - # so we just save them recursively - for source in self._sources.values(): - if isinstance(source, ProjectDataset): - source.save(**converter_kwargs) - - self.env.converters.get(DEFAULT_FORMAT).convert( - self.iterate_own(), dataset_save_dir, **converter_kwargs) - - project.save(save_dir) - except BaseException: - if not save_dir_existed and osp.isdir(save_dir): - shutil.rmtree(save_dir, ignore_errors=True) - raise + status = '' # ignore missing paths + + if status: + statuses[obj_path] = status + + return statuses + + def is_ref(self, rev): + try: + self.repo.commit(rev) + return True + except (ValueError, self.module().exc.BadName): + return False + + def has_commits(self): + return self.is_ref('HEAD') + + def get_tree(self, ref): + return self.repo.tree(ref) + + def write_tree(self, tree, base_path: str, + include_files: Optional[List[str]] = None): + os.makedirs(base_path, exist_ok=True) + + for obj in tree.traverse(visit_once=True): + if include_files and obj.path not in include_files: + continue + + path = osp.join(base_path, obj.path) + os.makedirs(osp.dirname(path), exist_ok=True) + if obj.type == 'blob': + with open(path, 'wb') as f: + obj.stream_data(f) + elif obj.type == 'tree': + pass + else: + raise ValueError("Unexpected object type in a " + "git tree: %s (%s)" % (obj.type, obj.hexsha)) @property - def config(self): - return self._project.config + def head(self) -> str: + return self.repo.head.commit.hexsha @property - def env(self): - return self._project.env + def branch(self) -> str: + if self.repo.head.is_detached: + return None + return self.repo.active_branch + + def rev_parse(self, ref: str) -> Tuple[str, str]: + """ + Expands named refs and tags. 
+ + Returns: object type, object hash + """ + obj = self.repo.rev_parse(ref) + return obj.type, obj.hexsha + + def ignore(self, paths: Union[str, List[str]], + mode: Union[None, str, IgnoreMode] = None, + gitignore: Optional[str] = None): + if not gitignore: + gitignore = '.gitignore' + repo_root = self._project_dir + gitignore = osp.abspath(osp.join(repo_root, gitignore)) + assert is_subpath(gitignore, base=repo_root), gitignore + + _update_ignore_file(paths, repo_root=repo_root, + mode=mode, filepath=gitignore) + + HASH_LEN = 40 + + @classmethod + def is_hash(cls, s: str) -> bool: + return len(s) == cls.HASH_LEN + + def log(self, depth=10) -> List[Tuple[Any, int]]: + """ + Returns: a list of (commit, index) pairs + """ + + commits = [] + + if not self.has_commits(): + return commits + + for commit in zip(self.repo.iter_commits(rev='HEAD'), range(depth)): + commits.append(commit) + return commits + +class DvcWrapper: + @staticmethod + def module(): + try: + import dvc + import dvc.env + import dvc.main + import dvc.repo + return dvc + except ModuleNotFoundError as e: + raise ModuleNotFoundError("Can't import the 'dvc' package. " + "Make sure DVC is installed, or install it with " + "'pip install datumaro[default]'." + ) from e + + def _dvc_dir(self): + return osp.join(self._project_dir, '.dvc') + + class DvcError(Exception): + pass + + def __init__(self, project_dir): + self._project_dir = project_dir + self.repo = None + + if osp.isdir(project_dir) and osp.isdir(self._dvc_dir()): + with logging_disabled(): + self.repo = self.module().repo.Repo(project_dir) + + @property + def initialized(self): + return self.repo is not None + + def init(self): + if self.initialized: + return + + with logging_disabled(): + self.repo = self.module().repo.Repo.init(self._project_dir) + + repo_dir = osp.join(self._project_dir, '.dvc') + _update_ignore_file([osp.join(repo_dir, 'plots')], + filepath=osp.join(repo_dir, '.gitignore'), + repo_root=repo_dir + ) + + def close(self): + if self.repo: + self.repo.close() + self.repo = None + + def __del__(self): + with suppress(Exception): + self.close() + + def checkout(self, targets=None): + args = ['checkout'] + if targets: + if isinstance(targets, str): + args.append(targets) + else: + args.extend(targets) + self._exec(args) + + def add(self, paths, dvc_path=None, no_commit=False, allow_external=False): + args = ['add'] + if dvc_path: + args.append('--file') + args.append(dvc_path) + os.makedirs(osp.dirname(dvc_path), exist_ok=True) + if no_commit: + args.append('--no-commit') + if allow_external: + args.append('--external') + if paths: + if isinstance(paths, str): + args.append(paths) + else: + args.extend(paths) + self._exec(args) + + def _exec(self, args, hide_output=True, answer_on_input='y'): + args = ['--cd', self._project_dir] + args + + # Avoid calling an extra process. Improves call performance and + # removes an extra console window on Windows. 
+ os.environ[self.module().env.DVC_NO_ANALYTICS] = '1' + + with ExitStack() as es: + es.callback(os.chdir, os.getcwd()) # restore cd after DVC + + if answer_on_input is not None: + def _input(*args): return answer_on_input + es.enter_context(unittest.mock.patch( + 'dvc.prompt.input', new=_input)) + + log.debug("Calling DVC main with args: %s", args) + + logs = es.enter_context(catch_logs('dvc')) + retcode = self.module().main.main(args) + + logs = logs.getvalue() + if retcode != 0: + raise self.DvcError(logs) + if not hide_output: + print(logs) + return logs + + def is_cached(self, obj_hash): + path = self.obj_path(obj_hash) + if not osp.isfile(path): + return False + + if obj_hash.endswith(self.DIR_HASH_SUFFIX): + with open(path) as f: + objects = json.load(f) + for entry in objects: + if not osp.isfile(self.obj_path(entry['md5'])): + return False + + return True + + def obj_path(self, obj_hash, root=None): + assert self.is_hash(obj_hash), obj_hash + if not root: + root = osp.join(self._project_dir, '.dvc', 'cache') + return osp.join(root, obj_hash[:2], obj_hash[2:]) + + def ignore(self, paths: Union[str, List[str]], + mode: Union[None, str, IgnoreMode] = None, + dvcignore: Optional[str] = None): + if not dvcignore: + dvcignore = '.dvcignore' + repo_root = self._project_dir + dvcignore = osp.abspath(osp.join(repo_root, dvcignore)) + assert is_subpath(dvcignore, base=repo_root), dvcignore + + _update_ignore_file(paths, repo_root=repo_root, + mode=mode, filepath=dvcignore) + + # This ruamel parser is needed to preserve comments, + # order and form (if multiple forms allowed by the standard) + # of the entries in the file. It can be reused. + yaml_parser = yaml.YAML(typ='rt') + + @classmethod + def get_hash_from_dvcfile(cls, path) -> str: + with open(path) as f: + contents = cls.yaml_parser.load(f) + return contents['outs'][0]['md5'] + + FILE_HASH_LEN = 32 + DIR_HASH_SUFFIX = '.dir' + DIR_HASH_LEN = FILE_HASH_LEN + len(DIR_HASH_SUFFIX) + + @classmethod + def is_file_hash(cls, s: str) -> bool: + return len(s) == cls.FILE_HASH_LEN + + @classmethod + def is_dir_hash(cls, s: str) -> bool: + return len(s) == cls.DIR_HASH_LEN and s.endswith(cls.DIR_HASH_SUFFIX) + + @classmethod + def is_hash(cls, s: str) -> bool: + return cls.is_file_hash(s) or cls.is_dir_hash(s) + + def write_obj(self, obj_hash, dst_dir, allow_links=True): + def _copy_obj(src, dst, link=False): + os.makedirs(osp.dirname(dst), exist_ok=True) + if link: + os.link(src, dst) + else: + shutil.copy(src, dst, follow_symlinks=True) + + src = self.obj_path(obj_hash) + if osp.isfile(src): + _copy_obj(src, dst_dir, link=allow_links) + return + + src += self.DIR_HASH_SUFFIX + if not osp.isfile(src): + raise UnknownRefError(obj_hash) + + with open(src) as f: + src_meta = json.load(f) + for entry in src_meta: + _copy_obj(self.obj_path(entry['md5']), + osp.join(dst_dir, entry['relpath']), link=allow_links) + + def remove_cache_obj(self, obj_hash: str): + src = self.obj_path(obj_hash) + if osp.isfile(src): + rmfile(src) + return + + src += self.DIR_HASH_SUFFIX + if not osp.isfile(src): + raise UnknownRefError(obj_hash) + + with open(src) as f: + src_meta = json.load(f) + for entry in src_meta: + entry_path = self.obj_path(entry['md5']) + if osp.isfile(entry_path): + rmfile(entry_path) + + rmfile(src) + +class Tree: + # can be: + # - attached to the work dir + # - attached to a revision + + def __init__(self, project: 'Project', + config: Union[None, Dict, Config, TreeConfig] = None, + rev: Union[None, 'Revision'] = None): + assert 
isinstance(project, Project) + assert not rev or project.is_ref(rev), rev + + if not isinstance(config, TreeConfig): + config = TreeConfig(config) + if config.format_version != 2: + raise ValueError("Unexpected tree config version '%s', expected 2" % + config.format_version) + self._config = config + + self._project = project + self._rev = rev + + self._sources = ProjectSources(self) + self._targets = ProjectBuildTargets(self) + + def save(self): + self.dump(self._config.config_path) + + def dump(self, path): + os.makedirs(osp.dirname(path), exist_ok=True) + self._config.dump(path) @property - def sources(self): + def sources(self) -> ProjectSources: return self._sources - def _save_branch_project(self, extractor, save_dir=None): - if not isinstance(extractor, Dataset): - extractor = Dataset.from_extractors( - extractor) # apply lazy transforms to avoid repeating traversals + @property + def build_targets(self) -> ProjectBuildTargets: + return self._targets - # NOTE: probably this function should be in the ViewModel layer - save_dir = osp.abspath(save_dir) - if save_dir: - dst_project = Project() - else: - if not self.config.project_dir: - raise ValueError("Either a save directory or a project " - "directory should be specified") - save_dir = self.config.project_dir + @property + def config(self) -> Config: + return self._config - dst_project = Project(Config(self.config)) - dst_project.config.remove('project_dir') - dst_project.config.remove('sources') - dst_project.config.project_name = osp.basename(save_dir) + @property + def env(self) -> Environment: + return self._project.env - dst_dataset = dst_project.make_dataset() - dst_dataset._categories = extractor.categories() - dst_dataset.update(extractor) + @property + def rev(self) -> Union[None, 'Revision']: + return self._rev - dst_dataset.save(save_dir=save_dir, merge=True) + def make_dataset(self, target: Optional[str] = None) -> Dataset: + if not target: + target = 'project' - def transform(self, method, *args, **kwargs): - if isinstance(method, str): - method = self.env.make_transform(method) + pipeline = self.build_targets.make_pipeline(target) + return ProjectBuilder(self._project, self).make_dataset(pipeline) - return method(self, *args, **kwargs) + @property + def is_working_tree(self) -> bool: + return not self._rev - def filter(self, expr: str, filter_annotations: bool = False, - remove_empty: bool = False) -> Dataset: - if filter_annotations: - return self.transform(XPathAnnotationsFilter, expr, remove_empty) - else: - return self.transform(XPathDatasetFilter, expr) + def source_data_dir(self, source) -> str: + if self.is_working_tree: + return self._project.source_data_dir(source) - def update(self, other): - for item in other: - self.put(item) - return self + obj_hash = self.build_targets[source].head.hash + return self._project.cache_path(obj_hash) - def select(self, pred): - class _DatasetFilter(Extractor): - def __init__(self, _): - super().__init__() - def __iter__(_): - return filter(pred, iter(self)) - def categories(_): - return self.categories() - - return self.transform(_DatasetFilter) - - def export(self, save_dir: str, format, **kwargs): - dataset = Dataset.from_extractors(self, env=self.env) - dataset.export(save_dir, format, **kwargs) - - def define_categories(self, categories): - assert not self._categories - self._categories = categories - - def transform_project(self, method, save_dir=None, **method_kwargs): - # NOTE: probably this function should be in the ViewModel layer - transformed = 
self.transform(method, **method_kwargs) - self._save_branch_project(transformed, save_dir=save_dir) - - def apply_model(self, model, save_dir=None, batch_size=1): - # NOTE: probably this function should be in the ViewModel layer - if isinstance(model, str): - model = self._project.make_executable_model(model) - - self.transform_project(ModelTransform, launcher=model, - save_dir=save_dir, batch_size=batch_size) - - def export_project(self, save_dir, converter, - filter_expr=None, filter_annotations=False, remove_empty=False): - # NOTE: probably this function should be in the ViewModel layer - dataset = self - if filter_expr: - dataset = dataset.filter(filter_expr, - filter_annotations=filter_annotations, - remove_empty=remove_empty) - - save_dir = osp.abspath(save_dir) - save_dir_existed = osp.exists(save_dir) - try: - os.makedirs(save_dir, exist_ok=True) - converter(dataset, save_dir) - except BaseException: - if not save_dir_existed: - shutil.rmtree(save_dir) - raise - - def filter_project(self, filter_expr, filter_annotations=False, - save_dir=None, remove_empty=False): - # NOTE: probably this function should be in the ViewModel layer - dataset = self - if filter_expr: - dataset = dataset.filter(filter_expr, - filter_annotations=filter_annotations, - remove_empty=remove_empty) - self._save_branch_project(dataset, save_dir=save_dir) + +class DiffStatus(Enum): + added = auto() + modified = auto() + removed = auto() + missing = auto() + foreign_modified = auto() + +Revision = NewType('Revision', str) # a commit hash or a named reference +ObjectId = NewType('ObjectId', str) # a commit or an object hash class Project: + @staticmethod + def find_project_dir(path: str) -> Optional[str]: + path = osp.abspath(path) + + if osp.basename(path) != ProjectLayout.aux_dir: + path = osp.join(path, ProjectLayout.aux_dir) + + if osp.isdir(path): + return path + + return None + + @staticmethod + @scoped + def migrate_from_v1_to_v2(src_dir: str, dst_dir: str, + skip_import_errors=False): + if not osp.isdir(src_dir): + raise FileNotFoundError("Source project is not found") + + if osp.exists(dst_dir): + raise FileExistsError("Output path already exists") + + src_dir = osp.abspath(src_dir) + dst_dir = osp.abspath(dst_dir) + if src_dir == dst_dir: + raise MigrationError("Source and destination paths are the same. 
" + "Project migration cannot be done inplace.") + + old_aux_dir = osp.join(src_dir, '.datumaro') + old_config = Config.parse(osp.join(old_aux_dir, 'config.yaml')) + if old_config.format_version != 1: + raise MigrationError("Failed to migrate project: " + "unexpected old version '%s'" % \ + old_config.format_version) + + on_error_do(rmtree, dst_dir, ignore_errors=True) + new_project = scope_add(Project.init(dst_dir)) + + new_wtree_dir = osp.join(new_project._aux_dir, + ProjectLayout.working_tree_dir) + os.makedirs(new_wtree_dir, exist_ok=True) + + old_plugins_dir = osp.join(old_aux_dir, 'plugins') + if osp.isdir(old_plugins_dir): + copytree(old_plugins_dir, + osp.join(new_project._aux_dir, ProjectLayout.plugins_dir)) + + old_models_dir = osp.join(old_aux_dir, 'models') + if osp.isdir(old_models_dir): + copytree(old_models_dir, + osp.join(new_project._aux_dir, ProjectLayout.models_dir)) + + new_project.env.load_plugins( + osp.join(new_project._aux_dir, ProjectLayout.plugins_dir)) + + new_tree_config = new_project.working_tree.config + new_local_config = new_project.config + + if 'models' in old_config: + for name, old_model in old_config.models.items(): + new_local_config.models[name] = Model({ + 'launcher': old_model['launcher'], + 'options': old_model['options'] + }) + + if 'sources' in old_config: + for name, old_source in old_config.sources.items(): + is_local = False + source_dir = osp.join(src_dir, 'sources', name) + url = osp.abspath(osp.join(source_dir, old_source['url'])) + rpath = None + if osp.exists(url): + if is_subpath(url, source_dir): + if url != source_dir: + rpath = osp.relpath(url, source_dir) + url = source_dir + is_local = True + elif osp.isfile(url): + url, rpath = osp.split(url) + elif not old_source['url']: + url = '' + + try: + source = new_project.import_source(name, + url=url, rpath=rpath, format=old_source['format'], + options=old_source['options']) + if is_local: + source.url = '' + + new_project.working_tree.make_dataset(name) + except Exception as e: + if not skip_import_errors: + raise MigrationError( + f"Failed to migrate the source '{name}'") from e + else: + log.warning(f"Failed to migrate the source '{name}'. " + "Try to add this source manually with " + "'datum add', once migration is finished. The " + "reason is: %s", e) + new_project.remove_source(name, + force=True, keep_data=False) + + old_dataset_dir = osp.join(src_dir, 'dataset') + if osp.isdir(old_dataset_dir): + # Such source cannot be represented in v2 directly. + # However, it can be considered a generated source with + # working tree data. + name = generate_next_name(list(new_tree_config.sources), + 'local_dataset', sep='-', default='1') + source = new_project.import_source(name, url=old_dataset_dir, + format=DEFAULT_FORMAT) + + # Make the source generated. It can only have local data. + source.url = '' + + new_project.save() + new_project.close() + + def __init__(self, path: Optional[str] = None, readonly=False): + if not path: + path = osp.curdir + found_path = self.find_project_dir(path) + if not found_path: + raise ProjectNotFoundError(path) + + old_config_path = osp.join(found_path, 'config.yaml') + if osp.isfile(old_config_path): + if Config.parse(old_config_path).format_version != 2: + raise OldProjectError() + + self._aux_dir = found_path + self._root_dir = osp.dirname(found_path) + + self._readonly = readonly + + # Force import errors on missing dependencies. 
+ # + # TODO: maybe allow class use in some cases, which not require + # Git or DVC + GitWrapper.module() + DvcWrapper.module() + + self._git = GitWrapper(self._root_dir) + self._dvc = DvcWrapper(self._root_dir) + + self._working_tree = None + self._head_tree = None + + local_config = osp.join(self._aux_dir, ProjectLayout.conf_file) + if osp.isfile(local_config): + self._config = ProjectConfig.parse(local_config) + else: + self._config = ProjectConfig() + + self._env = Environment() + + plugins_dir = osp.join(self._aux_dir, ProjectLayout.plugins_dir) + if osp.isdir(plugins_dir): + self._env.load_plugins(plugins_dir) + + def _init_vcs(self): + # DVC requires Git to be initialized + if not self._git.initialized: + self._git.init() + self._git.ignore([ + ProjectLayout.cache_dir, + ], gitignore=osp.join(self._aux_dir, '.gitignore')) + self._git.ignore([]) # create the file + if not self._dvc.initialized: + self._dvc.init() + self._dvc.ignore([ + osp.join(self._aux_dir, ProjectLayout.cache_dir), + osp.join(self._aux_dir, ProjectLayout.working_tree_dir), + ]) + self._git.repo.index.remove( + osp.join(self._root_dir, '.dvc', 'plots'), r=True) + self.commit('Initial commit', allow_empty=True) + @classmethod - def load(cls, path): + @scoped + def init(cls, path) -> 'Project': + existing_project = cls.find_project_dir(path) + if existing_project: + raise ProjectAlreadyExists(path) + path = osp.abspath(path) - config_path = osp.join(path, PROJECT_DEFAULT_CONFIG.env_dir, - PROJECT_DEFAULT_CONFIG.project_filename) - config = Config.parse(config_path) - config.project_dir = path - config.project_filename = osp.basename(config_path) - return Project(config) - - def save(self, save_dir=None): - config = self.config - - if save_dir is None: - assert config.project_dir - project_dir = config.project_dir + if osp.basename(path) != ProjectLayout.aux_dir: + path = osp.join(path, ProjectLayout.aux_dir) + + project_dir = osp.dirname(path) + if not osp.isdir(project_dir): + on_error_do(rmtree, project_dir, ignore_errors=True) + + os.makedirs(path, exist_ok=True) + + on_error_do(rmtree, osp.join(project_dir, ProjectLayout.cache_dir), + ignore_errors=True) + on_error_do(rmtree, osp.join(project_dir, ProjectLayout.tmp_dir), + ignore_errors=True) + os.makedirs(osp.join(path, ProjectLayout.cache_dir)) + os.makedirs(osp.join(path, ProjectLayout.tmp_dir)) + + on_error_do(rmtree, osp.join(project_dir, '.git'), ignore_errors=True) + on_error_do(rmtree, osp.join(project_dir, '.dvc'), ignore_errors=True) + project = Project(path) + project._init_vcs() + + return project + + def close(self): + if self._dvc: + self._dvc.close() + self._dvc = None + + if self._git: + self._git.close() + self._git = None + + def __del__(self): + with suppress(Exception): + self.close() + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + self.close() + + def save(self): + self._config.dump(osp.join(self._aux_dir, ProjectLayout.conf_file)) + + if self._working_tree: + self._working_tree.save() + + @property + def readonly(self) -> bool: + return self._readonly + + @property + def working_tree(self) -> Tree: + if self._working_tree is None: + self._working_tree = self.get_rev(None) + return self._working_tree + + @property + def head(self) -> Tree: + if self._head_tree is None: + self._head_tree = self.get_rev('HEAD') + return self._head_tree + + @property + def head_rev(self) -> Revision: + return self._git.head + + @property + def branch(self) -> str: + return self._git.branch + + @property + def config(self) -> 
Config: + return self._config + + @property + def env(self) -> Environment: + return self._env + + @property + def models(self) -> Dict[str, Model]: + return dict(self._config.models) + + def get_rev(self, rev: Union[None, Revision]) -> Tree: + """ + Reference conventions: + - None or "" - working dir + - "<40 symbols>" - revision hash + """ + + obj_type, obj_hash = self._parse_ref(rev) + assert obj_type == self._ObjectIdKind.tree, obj_type + + if self._is_working_tree_ref(obj_hash): + config_path = osp.join(self._aux_dir, + ProjectLayout.working_tree_dir, TreeLayout.conf_file) + if osp.isfile(config_path): + tree_config = TreeConfig.parse(config_path) + else: + tree_config = TreeConfig() + os.makedirs(osp.dirname(config_path), exist_ok=True) + tree_config.dump(config_path) + tree_config.config_path = config_path + tree_config.base_dir = osp.dirname(config_path) + tree = Tree(config=tree_config, project=self, rev=obj_hash) else: - project_dir = save_dir + if not self.is_rev_cached(obj_hash): + self._materialize_rev(obj_hash) + + rev_dir = self.cache_path(obj_hash) + tree_config = TreeConfig.parse(osp.join(rev_dir, + TreeLayout.conf_file)) + tree_config.base_dir = rev_dir + tree = Tree(config=tree_config, project=self, rev=obj_hash) + return tree + + def is_rev_cached(self, rev: Revision) -> bool: + obj_type, obj_hash = self._parse_ref(rev) + assert obj_type == self._ObjectIdKind.tree, obj_type + return self._is_cached(obj_hash) + + def is_obj_cached(self, obj_hash: ObjectId) -> bool: + return self._is_cached(obj_hash) or \ + self._can_retrieve_from_vcs_cache(obj_hash) + + @staticmethod + def _is_working_tree_ref(ref: Union[None, Revision, ObjectId]) -> bool: + return not ref + + class _ObjectIdKind(Enum): + # Project revision data. Currently, a Git commit hash. + tree = auto() + + # Source revision data. DVC directories and files. + blob = auto() + + def _parse_ref(self, ref: Union[None, Revision, ObjectId]) \ + -> Tuple[_ObjectIdKind, ObjectId]: + """ + Resolves the reference to an object hash. + """ - env_dir = osp.join(project_dir, config.env_dir) - save_dir = osp.abspath(env_dir) + if self._is_working_tree_ref(ref): + return self._ObjectIdKind.tree, ref - project_dir_existed = osp.exists(project_dir) - env_dir_existed = osp.exists(env_dir) try: - os.makedirs(save_dir, exist_ok=True) + obj_type, obj_hash = self._git.rev_parse(ref) + except Exception: # nosec - B110:try_except_pass + pass # Ignore git errors + else: + if obj_type != 'commit': + raise UnknownRefError(obj_hash) + + return self._ObjectIdKind.tree, obj_hash + + try: + assert self._dvc.is_hash(ref), ref + return self._ObjectIdKind.blob, ref + except Exception as e: + raise UnknownRefError(ref) from e + + def _materialize_rev(self, rev: Revision) -> str: + """ + Restores the revision tree data in the project cache from Git. + + Returns: cache object path + """ + # TODO: maybe avoid this operation by providing a virtual filesystem + # object + + # Allowed to be run when readonly, because it doesn't modify project + # data and doesn't hurt disk space. 
+ + obj_dir = self.cache_path(rev) + if osp.isdir(obj_dir): + return obj_dir + + tree = self._git.get_tree(rev) + self._git.write_tree(tree, obj_dir) + return obj_dir + + def _is_cached(self, obj_hash: ObjectId): + return osp.isdir(self.cache_path(obj_hash)) + + def cache_path(self, obj_hash: ObjectId) -> str: + assert self._git.is_hash(obj_hash) or self._dvc.is_hash(obj_hash), obj_hash + if self._dvc.is_dir_hash(obj_hash): + obj_hash = obj_hash[:self._dvc.FILE_HASH_LEN] + + return osp.join(self._aux_dir, ProjectLayout.cache_dir, + obj_hash[:2], obj_hash[2:]) + + def _can_retrieve_from_vcs_cache(self, obj_hash: ObjectId): + if not self._dvc.is_dir_hash(obj_hash): + dir_check = self._dvc.is_cached( + obj_hash + self._dvc.DIR_HASH_SUFFIX) + else: + dir_check = False + return dir_check or self._dvc.is_cached(obj_hash) + + def source_data_dir(self, name: str) -> str: + return osp.join(self._root_dir, name) + + def _source_dvcfile_path(self, name: str, + root: Optional[str] = None) -> str: + """ + root - Path to the tree root directory. If not set, + the working tree is used. + """ + + if not root: + root = osp.join(self._aux_dir, ProjectLayout.working_tree_dir) + return osp.join(root, TreeLayout.sources_dir, name, 'source.dvc') + + def _make_tmp_dir(self, suffix: Optional[str] = None): + project_tmp_dir = osp.join(self._aux_dir, ProjectLayout.tmp_dir) + os.makedirs(project_tmp_dir, exist_ok=True) + if suffix: + suffix = '_' + suffix + + return tempfile.TemporaryDirectory(suffix=suffix, dir=project_tmp_dir) + + def remove_cache_obj(self, ref: Union[Revision, ObjectId]): + if self.readonly: + raise ReadonlyProjectError() + + obj_type, obj_hash = self._parse_ref(ref) + + if self._is_cached(obj_hash): + rmtree(self.cache_path(obj_hash)) + + if obj_type == self._ObjectIdKind.tree: + # Revision metadata is cheap enough and needed to materialize + # the revision, so we keep it in the Git cache. 
+ pass + elif obj_type == self._ObjectIdKind.blob: + self._dvc.remove_cache_obj(obj_hash) + else: + raise ValueError("Unexpected object type '%s'" % obj_type) + + def validate_source_name(self, name: str): + if not name: + raise ValueError("Source name cannot be empty") + + disallowed_symbols = r"[^\\ \.\~\-\w]" + found_wrong_symbols = re.findall(disallowed_symbols, name) + if found_wrong_symbols: + raise ValueError("Source name contains invalid symbols: %s" % + found_wrong_symbols) + + valid_filename = make_file_name(name) + if valid_filename != name: + raise ValueError("Source name contains " + "invalid symbols: %s" % (set(name) - set(valid_filename)) ) - config_path = osp.join(save_dir, config.project_filename) - config.dump(config_path) - except BaseException: - if not env_dir_existed: - shutil.rmtree(save_dir, ignore_errors=True) - if not project_dir_existed: - shutil.rmtree(project_dir, ignore_errors=True) - raise + if name.startswith('.'): + raise ValueError("Source name can't start with '.'") + + reserved_names = {'dataset', 'build', 'project'} + if name.lower() in reserved_names: + raise ValueError("Source name is reserved for internal use") + + @scoped + def _download_source(self, url: str, dst_dir: str, no_cache: bool = False): + assert url + assert dst_dir + + dvcfile = osp.join(dst_dir, 'source.dvc') + data_dir = osp.join(dst_dir, 'data') + + log.debug(f"Copying from '{url}' to '{data_dir}'") + + if osp.isdir(url): + copytree(url, data_dir) + elif osp.isfile(url): + os.makedirs(data_dir, exist_ok=True) + shutil.copy(url, data_dir) + else: + raise UnexpectedUrlError(url) + on_error_do(rmtree, data_dir, ignore_errors=True) + + log.debug("Done") + + obj_hash = self.compute_source_hash(data_dir, + dvcfile=dvcfile, no_cache=no_cache, allow_external=True) + if not no_cache: + log.debug("Data is added to DVC cache") + log.debug("Data hash: '%s'", obj_hash) + + return obj_hash, dvcfile, data_dir @staticmethod - def generate(save_dir, config=None): - config = Config(config) - config.project_dir = save_dir - project = Project(config) - project.save(save_dir) - return project + def _get_source_hash(dvcfile): + obj_hash = DvcWrapper.get_hash_from_dvcfile(dvcfile) + if obj_hash.endswith(DvcWrapper.DIR_HASH_SUFFIX): + obj_hash = obj_hash[:-len(DvcWrapper.DIR_HASH_SUFFIX)] + return obj_hash + + @scoped + def compute_source_hash(self, data_dir: str, dvcfile: Optional[str] = None, + no_cache: bool = True, allow_external: bool = True) -> ObjectId: + if not dvcfile: + tmp_dir = scope_add(self._make_tmp_dir()) + dvcfile = osp.join(tmp_dir, 'source.dvc') + + self._dvc.add(data_dir, dvc_path=dvcfile, no_commit=no_cache, + allow_external=allow_external) + obj_hash = self._get_source_hash(dvcfile) + return obj_hash + + def refresh_source_hash(self, source: str, + no_cache: bool = True) -> ObjectId: + """ + Computes and updates the source hash in the working directory. + + Returns: hash + """ + + if self.readonly: + raise ReadonlyProjectError() + + build_target = self.working_tree.build_targets[source] + source_dir = self.source_data_dir(source) + + if not osp.isdir(source_dir): + return None + + dvcfile = self._source_dvcfile_path(source) + os.makedirs(osp.dirname(dvcfile), exist_ok=True) + obj_hash = self.compute_source_hash(source_dir, + dvcfile=dvcfile, no_cache=no_cache) + + build_target.head.hash = obj_hash + + return obj_hash + + def _materialize_obj(self, obj_hash: ObjectId) -> str: + """ + Restores the object data in the project cache from DVC. 
+ + Returns: cache object path + """ + # TODO: maybe avoid this operation by providing a virtual filesystem + # object + + # Allowed to be run when readonly, because it shouldn't hurt disk + # space, if object is materialized with symlinks. + + if not self._can_retrieve_from_vcs_cache(obj_hash): + raise MissingObjectError(obj_hash) + + dst_dir = self.cache_path(obj_hash) + if osp.isdir(dst_dir): + return dst_dir + + self._dvc.write_obj(obj_hash, dst_dir, allow_links=True) + return dst_dir + + @scoped + def import_source(self, name: str, url: Optional[str], + format: str, options: Optional[Dict] = None, + no_cache: bool = False, rpath: Optional[str] = None) -> Source: + """ + Adds a new source (dataset) to the working directory of the project. + + When 'path' is specified, will copy all the data from URL, but read + only the specified file. Required to support subtasks and subsets + in datasets. + + Parameters: + - name (str) - Name of the new source + - url (str) - URL of the new source. A path to a file or directory + - format (str) - Dataset format + - options (dict) - Options for the format Extractor + - no_cache (bool) - Don't put a copy of files into the project cache. + Can be used to reduce project cache size. + - rpath (str) - Used to specify a relative path to the dataset + inside of the directory pointed by URL. + + Returns: the new source config + """ + + if self.readonly: + raise ReadonlyProjectError() + + self.validate_source_name(name) + + if name in self.working_tree.sources: + raise SourceExistsError(name) + + data_dir = self.source_data_dir(name) + if osp.exists(data_dir): + if os.listdir(data_dir): + raise FileExistsError("Source directory '%s' already " + "exists" % data_dir) + os.rmdir(data_dir) + + if url: + url = osp.abspath(url) + if not osp.exists(url): + raise FileNotFoundError(url) + + if is_subpath(url, base=self._root_dir): + raise SourceUrlInsideProjectError() + + if rpath: + rpath = osp.normpath(osp.join(url, rpath)) + + if not osp.exists(rpath): + raise FileNotFoundError(rpath) + + if not is_subpath(rpath, base=url): + raise PathOutsideSourceError( + "Source data path is outside of the directory, " + "specified by source URL: '%s', '%s'" % (rpath, url)) + + rpath = osp.relpath(rpath, url) + else: + rpath = None + + config = Source({ + 'url': (url or '').replace('\\', '/'), + 'path': (rpath or '').replace('\\', '/'), + 'format': format, + 'options': options or {}, + }) + + if not config.is_generated: + dvcfile = self._source_dvcfile_path(name) + os.makedirs(osp.dirname(dvcfile), exist_ok=True) + + with self._make_tmp_dir() as tmp_dir: + obj_hash, tmp_dvcfile, tmp_data_dir = \ + self._download_source(url, tmp_dir, no_cache=no_cache) + + shutil.move(tmp_data_dir, data_dir) + on_error_do(rmtree, data_dir) + os.replace(tmp_dvcfile, dvcfile) + + config['hash'] = obj_hash + + self._git.ignore([data_dir]) + + config = self.working_tree.sources.add(name, config) + target = self.working_tree.build_targets.add_target(name) + target.root.hash = config.hash + + self.working_tree.save() + + return config + + def remove_source(self, name: str, force: bool = False, + keep_data: bool = True): + """ + Options: + - force (bool) - ignores errors and tries to wipe remaining data + - keep_data (bool) - leaves source data untouched + """ + + if self.readonly: + raise ReadonlyProjectError() + + if name not in self.working_tree.sources and not force: + raise UnknownSourceError(name) + + self.working_tree.sources.remove(name) + + data_dir = self.source_data_dir(name) + if not 
keep_data: + if osp.isdir(data_dir): + rmtree(data_dir) + + dvcfile = self._source_dvcfile_path(name) + if osp.isfile(dvcfile): + try: + rmfile(dvcfile) + except Exception: + if not force: + raise + + self.working_tree.build_targets.remove_target(name) + + self.working_tree.save() + + self._git.ignore([data_dir], mode='remove') + + def commit(self, message: str, no_cache: bool = False, + allow_empty: bool = False, allow_foreign: bool = False) -> Revision: + """ + Copies tree and objects from the working dir to the cache. + Creates a new commit. Moves the HEAD pointer to the new commit. + + Options: + - no_cache (bool) - don't put added dataset data into cache, + store only metainfo. Can be used to reduce storage size. + - allow_empty (bool) - allow commits with no changes. + - allow_foreign (bool) - allow commits with changes made outside of Datumaro. + + Returns: the new commit hash + """ + + if self.readonly: + raise ReadonlyProjectError() + + statuses = self.status() + + if not allow_empty and not statuses: + raise EmptyCommitError() + + for t, s in statuses.items(): + if s == DiffStatus.foreign_modified: + # TODO: compute a patch and a new stage, remove allow_foreign + if allow_foreign: + log.warning("The source '%s' has been changed " + "outside of Datumaro. It will be saved, but it will " + "only be available for reproduction from the cache.", t) + else: + raise ForeignChangesError( + "The source '%s' is changed outside Datumaro. You can " + "restore the latest source revision with 'checkout' " + "command." % t) + + for s in self.working_tree.sources: + self.refresh_source_hash(s, no_cache=no_cache) + + wtree_dir = osp.join(self._aux_dir, ProjectLayout.working_tree_dir) + self.working_tree.save() + self._git.add(wtree_dir, base=wtree_dir) + + extra_files = [ + osp.join(self._root_dir, '.dvc', '.gitignore'), + osp.join(self._root_dir, '.dvc', 'config'), + osp.join(self._root_dir, '.dvcignore'), + osp.join(self._root_dir, '.gitignore'), + osp.join(self._aux_dir, '.gitignore'), + ] + self._git.add(extra_files, base=self._root_dir) + + head = self._git.commit(message) + + rev_dir = self.cache_path(head) + copytree(wtree_dir, rev_dir) + for p in extra_files: + if osp.isfile(p): + dst_path = osp.join(rev_dir, osp.relpath(p, self._root_dir)) + os.makedirs(osp.dirname(dst_path), exist_ok=True) + shutil.copyfile(p, dst_path) + + self._head_tree = None + + return head @staticmethod - def import_from(path, format=None, env=None, **options): - if env is None: - env = Environment() - - if not format: - matches = env.detect_dataset(path) - if not matches: - raise NoMatchingFormatsError() - if 1 < len(matches): - raise MultipleFormatsMatchError(matches) - format = matches[0] - elif not env.is_format_known(format): - raise UnknownFormatError(format) - - if format in env.importers: - project = env.make_importer(format)(path, **options) - elif format in env.extractors: - project = Project(env=env) - project.add_source('source', { - 'url': path, - 'format': format, - 'options': options, - }) + def _move_dvc_dir(src_dir, dst_dir): + for name in {'config', '.gitignore'}: + os.replace(osp.join(src_dir, name), osp.join(dst_dir, name)) + + def checkout(self, rev: Union[None, Revision] = None, + sources: Union[None, str, Iterable[str]] = None, + force: bool = False): + """ + Copies tree and objects from the cache to the working tree. + + Sets HEAD to the specified revision, unless sources are specified. + When sources are specified, only copies objects from the cache to the working tree.
+ """ + + if self.readonly: + raise ReadonlyProjectError() + + if isinstance(sources, str): + sources = {sources} + elif sources is None: + sources = {} else: - raise UnknownFormatError(format) - return project + sources = set(sources) - def __init__(self, config=None, env=None): - self.config = Config(config, - fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA) - if env is None: - env = Environment(self.config) - env.models.batch_register(self.config.models) - env.sources.batch_register(self.config.sources) - env.load_plugins(osp.join(self.config.project_dir, - self.config.env_dir, self.config.plugins_dir)) - elif config is not None: - raise ValueError("env can only be provided when no config provided") - self.env = env - - def make_dataset(self): - return ProjectDataset(self) - - def add_source(self, name, value=None): - if value is None or isinstance(value, (dict, Config)): - value = Source(value) - self.config.sources[name] = value - self.env.sources.register(name, value) - - def remove_source(self, name): - self.config.sources.remove(name) - self.env.sources.unregister(name) - - def get_source(self, name): - try: - return self.config.sources[name] - except KeyError: - raise KeyError("Source '%s' is not found" % name) + if sources: + rev_tree = self.get_rev(rev or 'HEAD') + + # Check targets + for s in sources: + if not s in rev_tree.sources: + raise UnknownSourceError(s) + + rev_dir = rev_tree.config.base_dir + with self._make_tmp_dir() as tmp_dir: + dvcfiles = [] + + for s in sources: + dvcfile = self._source_dvcfile_path(s, root=rev_dir) + + tmp_dvcfile = osp.join(tmp_dir, s + '.dvc') + with open(dvcfile) as f: + conf = self._dvc.yaml_parser.load(f) - def get_subsets(self): - return self.config.subsets + conf['wdir'] = self._root_dir - def set_subsets(self, value): - if not value: - self.config.remove('subsets') + with open(tmp_dvcfile, 'w') as f: + self._dvc.yaml_parser.dump(conf, f) + + dvcfiles.append(tmp_dvcfile) + + self._dvc.checkout(dvcfiles) + + self._git.ignore(sources) + + for s in sources: + self.working_tree.config.sources[s] = \ + rev_tree.config.sources[s] + self.working_tree.config.build_targets[s] = \ + rev_tree.config.build_targets[s] + + self.working_tree.save() else: - self.config.subsets = value + # Check working tree for unsaved changes, + # set HEAD to the revision + # write revision tree to working tree + wtree_dir = osp.join(self._aux_dir, ProjectLayout.working_tree_dir) + self._git.checkout(rev, dst_dir=wtree_dir, clean=True, force=force) + self._move_dvc_dir(osp.join(wtree_dir, '.dvc'), + osp.join(self._root_dir, '.dvc')) - def add_model(self, name, value=None): - if value is None or isinstance(value, (dict, Config)): - value = Model(value) - self.env.register_model(name, value) - self.config.models[name] = value + self._working_tree = None - def get_model(self, name): - try: - return self.env.models.get(name) - except KeyError: - raise KeyError("Model '%s' is not found" % name) - - def remove_model(self, name): - self.config.models.remove(name) - self.env.unregister_model(name) - - def make_executable_model(self, name): - model = self.get_model(name) - return self.env.make_launcher(model.launcher, - **model.options, model_dir=osp.join( - self.config.project_dir, self.local_model_dir(name))) - - def make_source_project(self, name): - source = self.get_source(name) - - config = Config(self.config) - config.remove('sources') - config.remove('subsets') - project = Project(config) - project.add_source(name, source) - return project + # Restore sources from 
the commit. + # Work with the working tree instead of cache, to + # avoid extra memory use from materializing + # the head commit sources in the cache + rev_tree = self.working_tree + with self._make_tmp_dir() as tmp_dir: + dvcfiles = [] + + for s in rev_tree.sources: + dvcfile = self._source_dvcfile_path(s) + + tmp_dvcfile = osp.join(tmp_dir, s + '.dvc') + with open(dvcfile) as f: + conf = self._dvc.yaml_parser.load(f) + + conf['wdir'] = self._root_dir + + with open(tmp_dvcfile, 'w') as f: + self._dvc.yaml_parser.dump(conf, f) + + dvcfiles.append(tmp_dvcfile) + + self._dvc.checkout(dvcfiles) + + os.replace(osp.join(wtree_dir, '.gitignore'), + osp.join(self._root_dir, '.gitignore')) + os.replace(osp.join(wtree_dir, '.dvcignore'), + osp.join(self._root_dir, '.dvcignore')) + + self._working_tree = None + + def is_ref(self, ref: Union[None, str]) -> bool: + if self._is_working_tree_ref(ref): + return True + return self._git.is_ref(ref) + + def has_commits(self) -> bool: + return self._git.has_commits() + + def status(self) -> Dict[str, DiffStatus]: + wd = self.working_tree + + if not self.has_commits(): + return { s: DiffStatus.added for s in wd.sources } + + + head = self.head + + changed_targets = {} + + for t_name, wd_target in wd.build_targets.items(): + if t_name == ProjectBuildTargets.MAIN_TARGET: + continue + + if osp.isdir(self.source_data_dir(t_name)): + old_hash = wd_target.head.hash + new_hash = self.refresh_source_hash(t_name) + + if old_hash != new_hash: + changed_targets[t_name] = DiffStatus.foreign_modified + + for t_name in set(head.build_targets) | set(wd.build_targets): + if t_name == ProjectBuildTargets.MAIN_TARGET: + continue + if t_name in changed_targets: + continue + + head_target = head.build_targets.get(t_name) + wd_target = wd.build_targets.get(t_name) + + status = None + + if head_target is None: + status = DiffStatus.added + elif wd_target is None: + status = DiffStatus.removed + else: + if head_target != wd_target: + status = DiffStatus.modified + elif not osp.isdir(self.source_data_dir(t_name)): + status = DiffStatus.missing + + if status: + changed_targets[t_name] = status + + return changed_targets + + def history(self, max_count=10) -> List[Tuple[Revision, str]]: + return [(c.hexsha, c.message) for c, _ in self._git.log(max_count)] + + def diff(self, rev_a: Union[Tree, Revision], + rev_b: Union[Tree, Revision]) -> Dict[str, DiffStatus]: + """ + Compares 2 revision trees. 
+ + Returns: { target_name: status } for changed targets + """ + + if rev_a == rev_b: + return {} + + if isinstance(rev_a, str): + tree_a = self.get_rev(rev_a) + else: + tree_a = rev_a + + if isinstance(rev_b, str): + tree_b = self.get_rev(rev_b) + else: + tree_b = rev_b + + + changed_targets = {} + + for t_name in set(tree_a.build_targets) | set(tree_b.build_targets): + if t_name == ProjectBuildTargets.MAIN_TARGET: + continue + + head_target = tree_a.build_targets.get(t_name) + wd_target = tree_b.build_targets.get(t_name) + + status = None + + if head_target is None: + status = DiffStatus.added + elif wd_target is None: + status = DiffStatus.removed + else: + if head_target != wd_target: + status = DiffStatus.modified + + if status: + changed_targets[t_name] = status + + return changed_targets + + def model_data_dir(self, name: str) -> str: + return osp.join(self._aux_dir, ProjectLayout.models_dir, name) + + def make_model(self, name: str) -> Launcher: + model = self._config.models[name] + model_dir = self.model_data_dir(name) + if not osp.isdir(model_dir): + model_dir = None + return self._env.make_launcher(model.launcher, + **model.options, model_dir=model_dir) + + def add_model(self, name: str, launcher: str, options = None) -> Model: + if self.readonly: + raise ReadonlyProjectError() + + if not launcher in self.env.launchers: + raise KeyError("Unknown launcher '%s'" % launcher) + + if not name: + raise ValueError("Model name can't be empty") + + if name in self.models: + raise KeyError("Model '%s' already exists" % name) + + return self._config.models.set(name, { + 'launcher': launcher, + 'options': options or {} + }) + + def remove_model(self, name: str): + if self.readonly: + raise ReadonlyProjectError() - def local_model_dir(self, model_name): - return osp.join( - self.config.env_dir, self.config.models_dir, model_name) + if name not in self.models: + raise KeyError("Unknown model '%s'" % name) - def local_source_dir(self, source_name): - return osp.join(self.config.sources_dir, source_name) + data_dir = self.model_data_dir(name) + if osp.isdir(data_dir): + rmtree(data_dir) diff --git a/datumaro/components/validator.py b/datumaro/components/validator.py index fa8bc4441a..9cc0fa8885 100644 --- a/datumaro/components/validator.py +++ b/datumaro/components/validator.py @@ -5,6 +5,7 @@ from enum import Enum, auto from typing import Dict, List +from datumaro.components.cli_plugin import CliPlugin from datumaro.components.dataset import IDataset @@ -19,7 +20,7 @@ class TaskType(Enum): segmentation = auto() -class Validator: +class Validator(CliPlugin): def validate(self, dataset: IDataset) -> Dict: """ Returns the validation results of a dataset based on task type.
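Taken together, the changes to datumaro/components/project.py above introduce a Git/DVC-backed project API. The snippet below is a minimal usage sketch of that API as it appears in this patch; the project path, dataset path, source name and format are hypothetical placeholders, and error handling is omitted.

from datumaro.components.project import Project

project = Project.init('./my_project')                  # new project, Git and DVC are initialized
project.import_source('src1', url='/path/to/dataset', format='datumaro')
head = project.commit('Add a source')                   # snapshot the working tree, returns the revision hash
print(project.status())                                  # {} right after a commit
print(project.history())                                 # [(revision_hash, message), ...], newest first
dataset = project.working_tree.make_dataset('src1')      # build the dataset for the source
project.close()

The object can also be used as a context manager ('with Project(path) as project: ...'), since __enter__/__exit__ close the underlying Git and DVC repositories.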
diff --git a/datumaro/plugins/coco_format/converter.py b/datumaro/plugins/coco_format/converter.py index bb6cf5feac..90a5e7f9c9 100644 --- a/datumaro/plugins/coco_format/converter.py +++ b/datumaro/plugins/coco_format/converter.py @@ -519,20 +519,20 @@ def _split_tasks_string(s): @classmethod def build_cmdline_parser(cls, **kwargs): + kwargs['description'] = """ + Segmentation mask modes (--segmentation-mode):|n + - '{sm.guess.name}': guess the mode for each instance,|n + |s|suse 'is_crowd' attribute as hint|n + - '{sm.polygons.name}': save polygons,|n + |s|smerge and convert masks, prefer polygons|n + - '{sm.mask.name}': save masks,|n + |s|smerge and convert polygons, prefer masks + """.format(sm=SegmentationMode) parser = super().build_cmdline_parser(**kwargs) parser.add_argument('--segmentation-mode', choices=[m.name for m in SegmentationMode], default=SegmentationMode.guess.name, - help=""" - Save mode for instance segmentation:|n - - '{sm.guess.name}': guess the mode for each instance,|n - |s|suse 'is_crowd' attribute as hint|n - - '{sm.polygons.name}': save polygons,|n - |s|smerge and convert masks, prefer polygons|n - - '{sm.mask.name}': save masks,|n - |s|smerge and convert polygons, prefer masks|n - Default: %(default)s. - """.format(sm=SegmentationMode)) + help="Save mode for instance segmentation (default: %(default)s)") parser.add_argument('--crop-covered', action='store_true', help="Crop covered segments so that background objects' " "segmentation was more accurate (default: %(default)s)") diff --git a/datumaro/plugins/coco_format/importer.py b/datumaro/plugins/coco_format/importer.py index a915e81faf..c90e1e5e01 100644 --- a/datumaro/plugins/coco_format/importer.py +++ b/datumaro/plugins/coco_format/importer.py @@ -30,9 +30,6 @@ def detect(cls, path): return len(cls.find_sources(path)) != 0 def __call__(self, path, **extra_params): - from datumaro.components.project import Project # cyclic import - project = Project() - subsets = self.find_sources(path) if len(subsets) == 0: @@ -52,6 +49,7 @@ def __call__(self, path, **extra_params): "Only one type will be used: %s" \ % (", ".join(t.name for t in ann_types), selected_ann_type.name)) + sources = [] for ann_files in subsets.values(): for ann_type, ann_file in ann_files.items(): if ann_type in conflicting_types: @@ -61,14 +59,13 @@ def __call__(self, path, **extra_params): continue log.info("Found a dataset at '%s'" % ann_file) - source_name = osp.splitext(osp.basename(ann_file))[0] - project.add_source(source_name, { + sources.append({ 'url': ann_file, 'format': self._TASKS[ann_type], 'options': dict(extra_params), }) - return project + return sources @classmethod def find_sources(cls, path): diff --git a/datumaro/plugins/datumaro_format/converter.py b/datumaro/plugins/datumaro_format/converter.py index 270ea60a2b..c6e78142a3 100644 --- a/datumaro/plugins/datumaro_format/converter.py +++ b/datumaro/plugins/datumaro_format/converter.py @@ -349,17 +349,3 @@ def patch(cls, dataset, patch, save_dir, **kwargs): DatumaroPath.RELATED_IMAGES_DIR, item.subset, item.id) if osp.isdir(related_images_path): shutil.rmtree(related_images_path) - -class DatumaroProjectConverter(Converter): - @classmethod - def convert(cls, extractor, save_dir, **kwargs): - os.makedirs(save_dir, exist_ok=True) - - from datumaro.components.project import Project - project = Project.generate(save_dir, - config=kwargs.pop('project_config', None)) - - DatumaroConverter.convert(extractor, - save_dir=osp.join( - project.config.project_dir, 
project.config.dataset_dir), - **kwargs) diff --git a/datumaro/plugins/datumaro_format/extractor.py b/datumaro/plugins/datumaro_format/extractor.py index 0f70ce21a3..300b6c1f57 100644 --- a/datumaro/plugins/datumaro_format/extractor.py +++ b/datumaro/plugins/datumaro_format/extractor.py @@ -192,4 +192,5 @@ def _load_annotations(item): class DatumaroImporter(Importer): @classmethod def find_sources(cls, path): - return cls._find_sources_recursive(path, '.json', 'datumaro') + return cls._find_sources_recursive(path, '.json', 'datumaro', + dirname='annotations') diff --git a/datumaro/plugins/kitti_format/importer.py b/datumaro/plugins/kitti_format/importer.py index f898350d9e..8dd77e8521 100644 --- a/datumaro/plugins/kitti_format/importer.py +++ b/datumaro/plugins/kitti_format/importer.py @@ -7,7 +7,6 @@ import os.path as osp from datumaro.components.extractor import Importer -from datumaro.util.log_utils import logging_disabled from .format import KittiPath, KittiTask @@ -18,15 +17,7 @@ class KittiImporter(Importer): KittiTask.detection: ('kitti_detection', KittiPath.LABELS_DIR), } - @classmethod - def detect(cls, path): - with logging_disabled(log.WARN): - return len(cls.find_sources(path)) != 0 - def __call__(self, path, **extra_params): - from datumaro.components.project import Project # cyclic import - project = Project() - subsets = self.find_sources(path) if len(subsets) == 0: @@ -44,6 +35,7 @@ def __call__(self, path, **extra_params): "Only one type will be used: %s" \ % (", ".join(t.name for t in ann_types), selected_ann_type.name)) + sources = [] for ann_files in subsets.values(): for ann_type, ann_file in ann_files.items(): if ann_type in conflicting_types: @@ -53,14 +45,13 @@ def __call__(self, path, **extra_params): continue log.info("Found a dataset at '%s'" % ann_file) - source_name = osp.splitext(osp.basename(ann_file))[0] - project.add_source(source_name, { + sources.append({ 'url': ann_file, 'format': ann_type, 'options': dict(extra_params), }) - return project + return sources @classmethod def find_sources(cls, path): diff --git a/datumaro/plugins/validators.py b/datumaro/plugins/validators.py index aadb457a4c..315a0e0412 100644 --- a/datumaro/plugins/validators.py +++ b/datumaro/plugins/validators.py @@ -48,21 +48,29 @@ class _TaskValidator(Validator, CliPlugin): @classmethod def build_cmdline_parser(cls, **kwargs): parser = super().build_cmdline_parser(**kwargs) - parser.add_argument('-fs', '--few_samples_thr', default=1, type=int, - help="Threshold for giving a warning for minimum number of" - "samples per class") - parser.add_argument('-ir', '--imbalance_ratio_thr', default=50, type=int, - help="Threshold for giving data imbalance warning;" - "IR(imbalance ratio) = majority/minority") - parser.add_argument('-m', '--far_from_mean_thr', default=5.0, type=float, - help="Threshold for giving a warning that data is far from mean;" - "A constant used to define mean +/- k * standard deviation;") - parser.add_argument('-dr', '--dominance_ratio_thr', default=0.8, type=float, - help="Threshold for giving a warning for bounding box imbalance;" - "Dominace_ratio = ratio of Top-k bin to total in histogram;") - parser.add_argument('-k', '--topk_bins', default=0.1, type=float, + parser.add_argument('-fs', '--few-samples-thr', + default=1, type=int, + help="Threshold for giving a warning for minimum number of " + "samples per class (default: %(default)s)") + parser.add_argument('-ir', '--imbalance-ratio-thr', + default=50, type=int, + help="Threshold for giving data imbalance warning. 
" + "IR(imbalance ratio) = majority/minority " + "(default: %(default)s)") + parser.add_argument('-m', '--far-from-mean-thr', + default=5.0, type=float, + help="Threshold for giving a warning that data is far from mean. " + "A constant used to define mean +/- k * standard deviation " + "(default: %(default)s)") + parser.add_argument('-dr', '--dominance-ratio-thr', + default=0.8, type=float, + help="Threshold for giving a warning for bounding box imbalance. " + "Dominace_ratio = ratio of Top-k bin to total in histogram " + "(default: %(default)s)") + parser.add_argument('-k', '--topk-bins', default=0.1, type=float, help="Ratio of bins with the highest number of data" - "to total bins in the histogram; [0, 1]; 0.1 = 10%;") + "to total bins in the histogram. A value in the range [0, 1] " + "(default: %(default)s)") return parser def __init__(self, task_type, few_samples_thr=None, diff --git a/datumaro/plugins/voc_format/importer.py b/datumaro/plugins/voc_format/importer.py index fffd4e66d7..46cd1861d9 100644 --- a/datumaro/plugins/voc_format/importer.py +++ b/datumaro/plugins/voc_format/importer.py @@ -18,33 +18,6 @@ class VocImporter(Importer): VocTask.action_classification: ('voc_action', 'Action'), } - def __call__(self, path, **extra_params): - from datumaro.components.project import Project # cyclic import - project = Project() - - subsets = self.find_sources(path) - if len(subsets) == 0: - raise Exception("Failed to find 'voc' dataset at '%s'" % path) - - for config in subsets: - subset_path = config['url'] - extractor_type = config['format'] - - task = extractor_type.split('_')[1] - - opts = dict(config.get('options') or {}) - opts.update(extra_params) - - project.add_source('%s-%s' % - (task, osp.splitext(osp.basename(subset_path))[0]), - { - 'url': subset_path, - 'format': extractor_type, - 'options': opts, - }) - - return project - @classmethod def find_sources(cls, path): subsets = [] diff --git a/datumaro/util/command_targets.py b/datumaro/util/command_targets.py deleted file mode 100644 index 7a94074a29..0000000000 --- a/datumaro/util/command_targets.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (C) 2019-2021 Intel Corporation -# -# SPDX-License-Identifier: MIT - -from enum import Enum, auto -import argparse - -from datumaro.components.project import Project -from datumaro.util.image import load_image - - -class TargetKinds(Enum): - project = auto() - source = auto() - external_dataset = auto() - inference = auto() - image = auto() - -def is_project_name(value, project): - return value == project.config.project_name - -def is_project_path(value): - if value: - try: - Project.load(value) - return True - except Exception: # nosec - disable B110:try_except_pass check - pass - return False - -def is_project(value, project=None): - if is_project_path(value): - return True - elif project is not None: - return is_project_name(value, project) - - return False - -def is_source(value, project=None): - if project is not None: - try: - project.get_source(value) - return True - except KeyError: - pass - - return False - -def is_external_source(value): - return False - -def is_inference_path(value): - return False - -def is_image_path(value): - try: - return load_image(value) is not None - except Exception: - return False - - -class Target: - def __init__(self, kind, test, is_default=False, name=None): - self.kind = kind - self.test = test - self.is_default = is_default - self.name = name - - def _get_fields(self): - return [self.kind, self.test, self.is_default, self.name] - - def 
__str__(self): - return self.name or str(self.kind) - - def __len__(self): - return len(self._get_fields()) - - def __iter__(self): - return iter(self._get_fields()) - -def ProjectTarget(kind=TargetKinds.project, test=None, - is_default=False, name='project name or path', - project=None): - if test is None: - test = lambda v: is_project(v, project=project) - return Target(kind, test, is_default, name) - -def SourceTarget(kind=TargetKinds.source, test=None, - is_default=False, name='source name', - project=None): - if test is None: - test = lambda v: is_source(v, project=project) - return Target(kind, test, is_default, name) - -def ExternalDatasetTarget(kind=TargetKinds.external_dataset, - test=is_external_source, - is_default=False, name='external dataset path'): - return Target(kind, test, is_default, name) - -def InferenceTarget(kind=TargetKinds.inference, test=is_inference_path, - is_default=False, name='inference path'): - return Target(kind, test, is_default, name) - -def ImageTarget(kind=TargetKinds.image, test=is_image_path, - is_default=False, name='image path'): - return Target(kind, test, is_default, name) - - -def target_selector(*targets): - def selector(value): - for (kind, test, is_default, _) in targets: - if (is_default and (value == '' or value is None)) or test(value): - return (kind, value) - raise argparse.ArgumentTypeError('Value should be one of: %s' \ - % (', '.join([str(t) for t in targets]))) - return selector diff --git a/datumaro/util/image.py b/datumaro/util/image.py index fa17c8716c..e6c93ef032 100644 --- a/datumaro/util/image.py +++ b/datumaro/util/image.py @@ -25,7 +25,7 @@ class _IMAGE_BACKENDS(Enum): try: importlib.import_module('cv2') _IMAGE_BACKEND = _IMAGE_BACKENDS.cv2 -except ImportError: +except ModuleNotFoundError: import PIL _IMAGE_BACKEND = _IMAGE_BACKENDS.PIL _image_loading_errors = (*_image_loading_errors, PIL.UnidentifiedImageError) @@ -179,11 +179,11 @@ def decode_image(image_bytes, dtype=np.float32): def find_images(dirpath: str, exts: Union[str, Iterable[str]] = None, recursive: bool = False, max_depth: int = None) -> Iterator[str]: if isinstance(exts, str): - exts = ['.' + exts.lower().lstrip('.')] + exts = {'.' + exts.lower().lstrip('.')} elif exts is None: exts = IMAGE_EXTENSIONS else: - exts = list('.' + e.lower().lstrip('.') for e in exts) + exts = {'.' 
+ e.lower().lstrip('.') for e in exts} def _check_image_ext(filename: str): dotpos = filename.rfind('.') @@ -201,6 +201,10 @@ def _check_image_ext(filename: str): yield osp.join(d, filename) +def is_image(path: str): + trunk, ext = osp.splitext(osp.basename(path)) + return trunk and ext.lower() in IMAGE_EXTENSIONS and \ + osp.isfile(path) class lazy_image: def __init__(self, path, loader=None, cache=None): @@ -295,6 +299,8 @@ def has_size(self) -> bool: @property def size(self) -> Optional[Tuple[int, int]]: + "Returns (H, W)" + if self._size is None: try: data = self.data @@ -317,13 +323,16 @@ def __eq__(self, other): not self.has_data) def save(self, path): - src_ext = self.ext.lower() - dst_ext = osp.splitext(osp.basename(path))[1].lower() + cur_path = osp.abspath(self.path) + path = osp.abspath(path) + + cur_ext = self.ext.lower() + new_ext = osp.splitext(osp.basename(path))[1].lower() os.makedirs(osp.dirname(path), exist_ok=True) - if src_ext == dst_ext and osp.isfile(self.path): - if self.path != path: - shutil.copyfile(self.path, path) + if cur_ext == new_ext and osp.isfile(cur_path): + if cur_path != path: + shutil.copyfile(cur_path, path) else: save_image(path, self.data) @@ -369,14 +378,17 @@ def __eq__(self, other): not self.has_data) def save(self, path): - src_ext = self.ext.lower() - dst_ext = osp.splitext(osp.basename(path))[1].lower() + cur_path = osp.abspath(self.path) + path = osp.abspath(path) + + cur_ext = self.ext.lower() + new_ext = osp.splitext(osp.basename(path))[1].lower() os.makedirs(osp.dirname(path), exist_ok=True) - if src_ext == dst_ext and osp.isfile(self.path): - if self.path != path: - shutil.copyfile(self.path, path) - elif src_ext == dst_ext: + if cur_ext == new_ext and osp.isfile(cur_path): + if cur_path != path: + shutil.copyfile(cur_path, path) + elif cur_ext == new_ext: with open(path, 'wb') as f: f.write(self.get_bytes()) else: diff --git a/datumaro/util/log_utils.py b/datumaro/util/log_utils.py index 4533a7e483..62e5c35c2c 100644 --- a/datumaro/util/log_utils.py +++ b/datumaro/util/log_utils.py @@ -1,8 +1,9 @@ -# Copyright (C) 2020 Intel Corporation +# Copyright (C) 2020-2021 Intel Corporation # # SPDX-License-Identifier: MIT from contextlib import contextmanager +from io import StringIO import logging @@ -14,3 +15,21 @@ def logging_disabled(max_level=logging.CRITICAL): yield finally: logging.disable(previous_level) + +@contextmanager +def catch_logs(logger=None): + logger = logging.getLogger(logger) + + old_propagate = logger.propagate + prev_handlers = logger.handlers + + stream = StringIO() + handler = logging.StreamHandler(stream) + logger.handlers = [handler] + logger.propagate = False + + try: + yield stream + finally: + logger.handlers = prev_handlers + logger.propagate = old_propagate diff --git a/datumaro/util/os_util.py b/datumaro/util/os_util.py index 1a584b9623..2c3ea42a0f 100644 --- a/datumaro/util/os_util.py +++ b/datumaro/util/os_util.py @@ -5,11 +5,32 @@ from contextlib import ( ExitStack, contextmanager, redirect_stderr, redirect_stdout, ) +from io import StringIO +from typing import Iterable, Optional import importlib import os import os.path as osp +import re +import shutil import subprocess import sys +import unicodedata + +try: + # Declare functions to remove files and directories. 
+ # + # Use rmtree from GitPython to avoid the problem with removal of + # readonly files on Windows, which Git uses extensively + # It double checks if a file cannot be removed because of readonly flag + from git.util import rmfile, rmtree # pylint: disable=unused-import + import git.util + git.util.HIDE_WINDOWS_KNOWN_ERRORS = False + +except ModuleNotFoundError: + from os import remove as rmfile # pylint: disable=unused-import + from shutil import rmtree as rmtree # pylint: disable=unused-import + +from . import cast DEFAULT_MAX_DEPTH = 10 @@ -31,8 +52,6 @@ def import_foreign_module(name, path, package=None): sys.modules.pop(name, None) # remove from cache module = importlib.import_module(name, package=package) sys.modules.pop(name) # remove from cache - except Exception: - raise finally: sys.path = default_path return module @@ -49,6 +68,70 @@ def walk(path, max_depth=None): yield dirpath, dirnames, filenames +def copytree(src, dst): + # Serves as a replacement for shutil.copytree(). + # + # Shutil works very slow pre 3.8 + # https://docs.python.org/3/library/shutil.html#platform-dependent-efficient-copy-operations + # https://bugs.python.org/issue33671 + + if sys.version_info >= (3, 8): + shutil.copytree(src, dst) + return + + assert src and dst + src = osp.abspath(src) + dst = osp.abspath(dst) + + if not osp.isdir(src): + raise FileNotFoundError("Source directory '%s' doesn't exist" % src) + + if osp.isdir(dst): + raise FileExistsError("Destination directory '%s' already exists" % dst) + + dst_basedir = osp.dirname(dst) + if dst_basedir: + os.makedirs(dst_basedir, exist_ok=True) + + try: + if sys.platform == 'windows': + # Ignore + # B603: subprocess_without_shell_equals_true + # B607: start_process_with_partial_path + # In this case we control what is called and command arguments + # PATH overriding is considered low risk + subprocess.check_output(["xcopy", src, dst, # nosec + "/s", "/e", "/q", "/y", "/i"], + stderr=subprocess.STDOUT, universal_newlines=True) + elif sys.platform == 'linux': + # As above + subprocess.check_output(["cp", "-r", '--', src, dst], # nosec + stderr=subprocess.STDOUT, universal_newlines=True) + else: + shutil.copytree(src, dst) + except subprocess.CalledProcessError as e: + raise Exception("Failed to copy data. The command '%s' " + "has failed with the following output: '%s'" % (e.cmd, e.stdout)) \ + from e + +@contextmanager +def suppress_output(stdout: bool = True, stderr: bool = False): + with open(os.devnull, 'w') as devnull, ExitStack() as es: + if stdout: + es.enter_context(redirect_stdout(devnull)) + elif stderr: + es.enter_context(redirect_stderr(devnull)) + + yield + +@contextmanager +def catch_output(): + stdout = StringIO() + stderr = StringIO() + + with redirect_stdout(stdout), redirect_stderr(stderr): + yield stdout, stderr + def dir_items(path, ext, truncate_ext=False): items = [] for f in os.listdir(path): @@ -75,31 +158,59 @@ def split_path(path): return parts -@contextmanager -def suppress_output(stdout: bool = True, stderr: bool = False): - with open(os.devnull, 'w') as devnull: - es = ExitStack() - - if stdout: - es.enter_context(redirect_stdout(devnull)) - - if stderr: - es.enter_context(redirect_stderr(devnull)) +def is_subpath(path: str, base: str) -> bool: + """ + Tests if a path is subpath of another path or the paths are equal. 
+ """ - with es: - yield + base = osp.abspath(base) + path = osp.abspath(path) + return osp.join(path, '').startswith(osp.join(base, '')) -def make_file_name(s): +def make_file_name(s: str) -> str: # adapted from # https://docs.djangoproject.com/en/2.1/_modules/django/utils/text/#slugify """ Normalizes string, converts to lowercase, removes non-alpha characters, and converts spaces to hyphens. """ - import re - import unicodedata s = unicodedata.normalize('NFKD', s).encode('ascii', 'ignore') s = s.decode() s = re.sub(r'[^\w\s-]', '', s).strip().lower() s = re.sub(r'[-\s]+', '-', s) return s + +def generate_next_name(names: Iterable[str], basename: str, + sep: str = '.', suffix: str = '', default: Optional[str] = None) -> str: + """ + Generates the "next" name by appending a next index to the occurrence + of the basename with the highest index in the input collection. + + Returns: next string name + + Example: + + Inputs: + name_abc + name_base + name_base1 + name_base5 + + Basename: name_base + + Output: name_base6 + """ + + pattern = re.compile(r'%s(?:%s(\d+))?%s' % \ + tuple(map(re.escape, [basename, sep, suffix]))) + matches = [match for match in (pattern.match(n) for n in names) if match] + + max_idx = max([cast(match[1], int, 0) for match in matches], default=None) + if max_idx is None: + if default is not None: + idx = sep + str(default) + else: + idx = '' + else: + idx = sep + str(max_idx + 1) + return basename + idx + suffix diff --git a/datumaro/util/test_utils.py b/datumaro/util/test_utils.py index 039735d764..5933db2b70 100644 --- a/datumaro/util/test_utils.py +++ b/datumaro/util/test_utils.py @@ -3,7 +3,8 @@ # SPDX-License-Identifier: MIT from enum import Enum, auto -from typing import Collection, Union +from glob import glob +from typing import Collection, Optional, Union import inspect import os import os.path as osp @@ -11,18 +12,10 @@ from typing_extensions import Literal -try: - # Use rmtree from GitPython to avoid the problem with removal of - # readonly files on Windows, which Git uses extensively - # It double checks if a file cannot be removed because of readonly flag - from git.util import rmfile, rmtree -except ImportError: - from shutil import rmtree - from os import remove as rmfile - from datumaro.components.annotation import AnnotationType from datumaro.components.dataset import Dataset, IDataset from datumaro.util import filter_dict, find +from datumaro.util.os_util import rmfile, rmtree class Dimensions(Enum): @@ -40,13 +33,11 @@ def __init__(self, path, is_dir=False): def __enter__(self): return self.path - # pylint: disable=redefined-builtin - def __exit__(self, type=None, value=None, traceback=None): + def __exit__(self, exc_type=None, exc_value=None, traceback=None): if self.is_dir: rmtree(self.path) else: rmfile(self.path) - # pylint: enable=redefined-builtin class TestDir(FileRemover): """ @@ -55,18 +46,35 @@ class TestDir(FileRemover): Usage: - with TestDir() as test_dir: - ... + with TestDir() as test_dir: + ... """ - def __init__(self, path=None): + def __init__(self, path: Optional[str] = None, frame_id: int = 2): + if not path: + prefix = f'temp_{current_function_name(frame_id)}-' + else: + prefix = None + self._prefix = prefix + + super().__init__(path, is_dir=True) + + def __enter__(self) -> str: + """ + Creates a test directory. 
+ + Returns: path to the directory + """ + + path = self.path + if path is None: - path = osp.abspath('temp_%s-' % current_function_name(2)) - path = tempfile.mkdtemp(dir=os.getcwd(), prefix=path) + path = tempfile.mkdtemp(dir=os.getcwd(), prefix=self._prefix) + self.path = path else: os.makedirs(path, exist_ok=False) - super().__init__(path, is_dir=True) + return path def compare_categories(test, expected, actual): test.assertEqual( @@ -226,6 +234,26 @@ def test_save_and_load(test, source_dataset, converter, test_dir, importer, compare = compare_datasets compare(test, expected=target_dataset, actual=parsed_dataset, **kwargs) +def compare_dirs(test, expected: str, actual: str): + """ + Compares file and directory structures in the given directories. + Empty directories are skipped. + """ + skip_empty_dirs = True + + for a_path in glob(osp.join(expected, '**', '*'), recursive=True): + rel_path = osp.relpath(a_path, expected) + b_path = osp.join(actual, rel_path) + if osp.isdir(a_path): + if not (skip_empty_dirs and not os.listdir(a_path)): + test.assertTrue(osp.isdir(b_path), rel_path) + continue + + test.assertTrue(osp.isfile(b_path), rel_path) + with open(a_path, 'rb') as a_file, \ + open(b_path, 'rb') as b_file: + test.assertEqual(a_file.read(), b_file.read(), rel_path) + def run_datum(test, *args, expected_code=0): from datumaro.cli.__main__ import main test.assertEqual(expected_code, main(args), str(args)) diff --git a/pytest.ini b/pytest.ini index 976c84825b..d48081e927 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,8 @@ [pytest] python_classes = python_functions = + +addopts = + # FIXME: Disable capture to avoid an infinite loop when tests use logging + # https://github.com/pytest-dev/pytest/issues/5502 + --capture=no \ No newline at end of file diff --git a/requirements-core.txt b/requirements-core.txt index 1392ff330f..1a0113448b 100644 --- a/requirements-core.txt +++ b/requirements-core.txt @@ -1,10 +1,12 @@ attrs>=21.1.0 defusedxml>=0.6.0 -GitPython>=3.0.8 lxml>=4.4.1 matplotlib>=3.3.1 +networkx>=2.5 numpy>=1.17.3 Pillow>=6.1.0 +ruamel.yaml>=0.17.0 +typing_extensions>=3.7.4.3 # Avoid 2.0.2 Linux binary distribution because of # a conflict in numpy versions with TensorFlow: diff --git a/requirements-default.txt b/requirements-default.txt new file mode 100644 index 0000000000..5b29e78320 --- /dev/null +++ b/requirements-default.txt @@ -0,0 +1,2 @@ +dvc>=2.3.0 +GitPython>=3.0.8 diff --git a/requirements.txt b/requirements.txt index fff66a6aa6..c2c6723271 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,15 @@ Cython>=0.27.3 # include before pycocotools -r requirements-core.txt --no-binary=pycocotools # https://github.com/openvinotoolkit/datumaro/issues/253 +-r requirements-default.txt + opencv-python-headless>=4.1.0.25 pandas>=1.1.5 + +# testing pytest>=5.3.5 # linters bandit>=1.7.0 isort~=5.9 pylint>=2.7.0 +coverage diff --git a/setup.py b/setup.py index 7a09f1cbb2..651252ed51 100644 --- a/setup.py +++ b/setup.py @@ -34,16 +34,20 @@ def find_version(project_dir=None): version = version_text[match.start(1) : match.end(1)] return version -def get_requirements(): - with open('requirements-core.txt') as fh: - requirements = [fh.read()] +CORE_REQUIREMENTS_FILE = 'requirements-core.txt' +DEFAULT_REQUIREMENTS_FILE = 'requirements-default.txt' - if strtobool(os.getenv('DATUMARO_HEADLESS', '0').lower()): - requirements.append('opencv-python-headless') - else: - requirements.append('opencv-python') +def parse_requirements(filename=CORE_REQUIREMENTS_FILE): + with 
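The reworked test utilities above give `TestDir` a per-test directory prefix and add `compare_dirs()`, which asserts that two directory trees contain the same files with identical bytes (empty directories are skipped). A minimal sketch of how the two might be combined in a test; the file names and contents below are made up for illustration:

```python
from unittest import TestCase
import os
import os.path as osp

from datumaro.util.test_utils import TestDir, compare_dirs


class DirComparisonExample(TestCase):
    def test_trees_are_equal(self):
        # TestDir() creates a uniquely named temporary directory
        # (prefixed with the test name) and removes it on exit
        with TestDir() as test_dir:
            expected = osp.join(test_dir, 'expected')
            actual = osp.join(test_dir, 'actual')
            for d in (expected, actual):
                os.makedirs(d)
                with open(osp.join(d, 'data.txt'), 'w') as f:
                    f.write('same contents')

            # Walks the "expected" tree and checks every file against "actual"
            compare_dirs(self, expected=expected, actual=actual)
```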
open(filename) as fh: + return fh.readlines() - return requirements +CORE_REQUIREMENTS = parse_requirements(CORE_REQUIREMENTS_FILE) +if strtobool(os.getenv('DATUMARO_HEADLESS', '0').lower()): + CORE_REQUIREMENTS.append('opencv-python-headless') +else: + CORE_REQUIREMENTS.append('opencv-python') + +DEFAULT_REQUIREMENTS = parse_requirements(DEFAULT_REQUIREMENTS_FILE) with open('README.md', 'r') as fh: long_description = fh.read() @@ -61,17 +65,18 @@ def get_requirements(): long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/openvinotoolkit/datumaro", - packages=setuptools.find_packages(exclude=['tests*']), + packages=setuptools.find_packages(include=['datumaro']), classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], python_requires='>=3.6', - install_requires=get_requirements(), + install_requires=CORE_REQUIREMENTS, extras_require={ 'tf': ['tensorflow'], 'tf-gpu': ['tensorflow-gpu'], + 'default': DEFAULT_REQUIREMENTS, }, entry_points={ 'console_scripts': [ diff --git a/tests/assets/compat/v0.1/project/.datumaro/config.yaml b/tests/assets/compat/v0.1/project/.datumaro/config.yaml new file mode 100644 index 0000000000..9e478bc796 --- /dev/null +++ b/tests/assets/compat/v0.1/project/.datumaro/config.yaml @@ -0,0 +1,17 @@ +format_version: 1 +models: {} +project_name: undefined +subsets: [] +sources: + source1: + format: datumaro + options: {} + url: '' + source2: + format: datumaro + options: {} + url: annotations/test.json + source3: + format: my + options: {} + url: '' diff --git a/tests/assets/compat/v0.1/project/.datumaro/plugins/__init__.py b/tests/assets/compat/v0.1/project/.datumaro/plugins/__init__.py new file mode 100644 index 0000000000..d73f273b03 --- /dev/null +++ b/tests/assets/compat/v0.1/project/.datumaro/plugins/__init__.py @@ -0,0 +1,9 @@ +from datumaro.components.extractor import DatasetItem, SourceExtractor + + +class MyExtractor(SourceExtractor): + def __iter__(self): + yield from [ + DatasetItem('1'), + DatasetItem('2'), + ] diff --git a/tests/assets/compat/v0.1/project/dataset/annotations/train.json b/tests/assets/compat/v0.1/project/dataset/annotations/train.json new file mode 100644 index 0000000000..6125bb7250 --- /dev/null +++ b/tests/assets/compat/v0.1/project/dataset/annotations/train.json @@ -0,0 +1,44 @@ +{ + "info":{ + + }, + "categories":{ + "label":{ + "labels":[ + { + "name":"a", + "parent":"", + "attributes":[ + + ] + }, + { + "name":"b", + "parent":"", + "attributes":[ + + ] + } + ], + "attributes":[ + + ] + } + }, + "items":[ + { + "id":"2", + "annotations":[ + { + "id":0, + "type":"label", + "attributes":{ + + }, + "group":0, + "label_id":0 + } + ] + } + ] + } diff --git a/tests/assets/compat/v0.1/project/sources/source1/annotations/train.json b/tests/assets/compat/v0.1/project/sources/source1/annotations/train.json new file mode 100644 index 0000000000..01d31f9f83 --- /dev/null +++ b/tests/assets/compat/v0.1/project/sources/source1/annotations/train.json @@ -0,0 +1,44 @@ +{ + "info":{ + + }, + "categories":{ + "label":{ + "labels":[ + { + "name":"a", + "parent":"", + "attributes":[ + + ] + }, + { + "name":"b", + "parent":"", + "attributes":[ + + ] + } + ], + "attributes":[ + + ] + } + }, + "items":[ + { + "id":"0", + "annotations":[ + { + "id":0, + "type":"label", + "attributes":{ + + }, + "group":0, + "label_id":0 + } + ] + } + ] + } diff --git a/tests/assets/compat/v0.1/project/sources/source2/annotations/test.json 
b/tests/assets/compat/v0.1/project/sources/source2/annotations/test.json new file mode 100644 index 0000000000..1ef6e66e13 --- /dev/null +++ b/tests/assets/compat/v0.1/project/sources/source2/annotations/test.json @@ -0,0 +1,44 @@ +{ + "info":{ + + }, + "categories":{ + "label":{ + "labels":[ + { + "name":"a", + "parent":"", + "attributes":[ + + ] + }, + { + "name":"b", + "parent":"", + "attributes":[ + + ] + } + ], + "attributes":[ + + ] + } + }, + "items":[ + { + "id":"1", + "annotations":[ + { + "id":0, + "type":"label", + "attributes":{ + + }, + "group":0, + "label_id":1 + } + ] + } + ] + } diff --git a/tests/cli/test_diff.py b/tests/cli/test_diff.py index f1504ac3fd..04b1bbcac5 100644 --- a/tests/cli/test_diff.py +++ b/tests/cli/test_diff.py @@ -14,6 +14,7 @@ from datumaro.components.project import Dataset from datumaro.util.image import Image from datumaro.util.test_utils import TestDir +from datumaro.util.test_utils import run_datum as run from ..requirements import Requirements, mark_requirement @@ -123,3 +124,64 @@ def test_can_compare_projects(self): # just a smoke test visualizer.save(dataset1, dataset2) self.assertNotEqual(0, os.listdir(osp.join(test_dir))) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_run_distance_diff(self): + dataset1 = Dataset.from_iterable([ + DatasetItem(id=100, subset='train', image=np.ones((10, 6, 3)), + annotations=[ + Bbox(1, 2, 3, 4, label=0), + ]), + ], categories=['a', 'b']) + + dataset2 = Dataset.from_iterable([ + DatasetItem(id=100, subset='train', image=np.ones((10, 6, 3)), + annotations=[ + Bbox(1, 2, 3, 4, label=1), + Bbox(5, 6, 7, 8, label=2), + ]), + ], categories=['a', 'b', 'c']) + + with TestDir() as test_dir: + dataset1_url = osp.join(test_dir, 'dataset1') + dataset2_url = osp.join(test_dir, 'dataset2') + + dataset1.export(dataset1_url, 'coco', save_images=True) + dataset2.export(dataset2_url, 'voc', save_images=True) + + result_dir = osp.join(test_dir, 'cmp_result') + run(self, 'diff', dataset1_url + ':coco', dataset2_url + ':voc', + '-m', 'distance', '-o', result_dir) + + self.assertEqual({'bbox_confusion.png', 'train'}, + set(os.listdir(result_dir))) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_run_equality_diff(self): + dataset1 = Dataset.from_iterable([ + DatasetItem(id=100, subset='train', image=np.ones((10, 6, 3)), + annotations=[ + Bbox(1, 2, 3, 4, label=0), + ]), + ], categories=['a', 'b']) + + dataset2 = Dataset.from_iterable([ + DatasetItem(id=100, subset='train', image=np.ones((10, 6, 3)), + annotations=[ + Bbox(1, 2, 3, 4, label=1), + Bbox(5, 6, 7, 8, label=2), + ]), + ], categories=['a', 'b', 'c']) + + with TestDir() as test_dir: + dataset1_url = osp.join(test_dir, 'dataset1') + dataset2_url = osp.join(test_dir, 'dataset2') + + dataset1.export(dataset1_url, 'coco', save_images=True) + dataset2.export(dataset2_url, 'voc', save_images=True) + + result_dir = osp.join(test_dir, 'cmp_result') + run(self, 'diff', dataset1_url + ':coco', dataset2_url + ':voc', + '-m', 'equality', '-o', result_dir) + + self.assertEqual({'diff.json'}, set(os.listdir(result_dir))) diff --git a/tests/cli/test_image_zip_format.py b/tests/cli/test_image_zip_format.py index 4088a76d41..9b5ca0475a 100644 --- a/tests/cli/test_image_zip_format.py +++ b/tests/cli/test_image_zip_format.py @@ -32,13 +32,14 @@ def test_can_save_and_load(self): zip_path = osp.join(test_dir, 'images.zip') make_zip_archive(test_dir, zip_path) - run(self, 'create', '-o', test_dir) - run(self, 'add', 'path', '-p', test_dir, '-f', 
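The new diff tests drive the CLI through `run_datum`, which simply forwards its arguments to the `main()` entry point. Outside the test suite, the same comparison could be invoked roughly as in the sketch below; the dataset paths are placeholders, and `:coco` / `:voc` are the revpath format suffixes used in the tests above:

```python
from datumaro.cli.__main__ import main

# Roughly equivalent to the shell command:
#   datum diff <dataset1>:coco <dataset2>:voc -m distance -o <result_dir>
# '-m' picks the comparison method ('distance' or 'equality'),
# '-o' is the directory for the generated report.
retcode = main([
    'diff',
    '/data/dataset1:coco',   # placeholder paths
    '/data/dataset2:voc',
    '-m', 'distance',
    '-o', '/data/cmp_result',
])
assert retcode == 0
```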
'image_zip', zip_path) - - export_path = osp.join(test_dir, 'export.zip') - run(self, 'export', '-p', test_dir, '-f', 'image_zip', - '-o', test_dir, '--overwrite', '--', - '--name', osp.basename(export_path) + proj_dir = osp.join(test_dir, 'proj') + run(self, 'create', '-o', proj_dir) + run(self, 'add', '-p', proj_dir, '-f', 'image_zip', zip_path) + + result_dir = osp.join(test_dir, 'result') + export_path = osp.join(result_dir, 'export.zip') + run(self, 'export', '-p', proj_dir, '-f', 'image_zip', + '-o', result_dir, '--', '--name', osp.basename(export_path) ) parsed_dataset = Dataset.import_from(export_path, format='image_zip') @@ -51,7 +52,7 @@ def test_can_export_zip_images_from_coco_dataset(self): 'tests', 'assets', 'coco_dataset') run(self, 'create', '-o', test_dir) - run(self, 'add', 'path', '-p', test_dir, '-f', 'coco', coco_dir) + run(self, 'add', '-p', test_dir, '-f', 'coco', coco_dir) export_path = osp.join(test_dir, 'export.zip') run(self, 'export', '-p', test_dir, '-f', 'image_zip', @@ -75,11 +76,12 @@ def test_can_change_extension_for_images_in_zip(self): zip_path = osp.join(test_dir, 'images.zip') make_zip_archive(test_dir, zip_path) - run(self, 'create', '-o', test_dir) - run(self, 'add', 'path', '-p', test_dir, '-f', 'image_zip', zip_path) + proj_dir = osp.join(test_dir, 'proj') + run(self, 'create', '-o', proj_dir) + run(self, 'add', '-p', proj_dir, '-f', 'image_zip', zip_path) export_path = osp.join(test_dir, 'export.zip') - run(self, 'export', '-p', test_dir, '-f', 'image_zip', + run(self, 'export', '-p', proj_dir, '-f', 'image_zip', '-o', test_dir, '--overwrite', '--', '--name', osp.basename(export_path), '--image-ext', '.png') diff --git a/tests/cli/test_merge.py b/tests/cli/test_merge.py new file mode 100644 index 0000000000..5e5ca62538 --- /dev/null +++ b/tests/cli/test_merge.py @@ -0,0 +1,119 @@ +from unittest import TestCase +import os.path as osp + +import numpy as np + +from datumaro.components.annotation import ( + AnnotationType, Bbox, LabelCategories, MaskCategories, +) +from datumaro.components.extractor import DatasetItem +from datumaro.components.project import Dataset, Project +from datumaro.util.test_utils import TestDir, compare_datasets +from datumaro.util.test_utils import run_datum as run +import datumaro.plugins.voc_format.format as VOC + +from ..requirements import Requirements, mark_requirement + + +class MergeTest(TestCase): + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_run_self_merge(self): + dataset1 = Dataset.from_iterable([ + DatasetItem(id=100, subset='train', image=np.ones((10, 6, 3)), + annotations=[ + Bbox(1, 2, 3, 3, label=0), + ]), + ], categories=['a', 'b']) + + dataset2 = Dataset.from_iterable([ + DatasetItem(id=100, subset='train', image=np.ones((10, 6, 3)), + annotations=[ + Bbox(1, 2, 3, 4, label=1), + Bbox(5, 6, 2, 3, label=2), + ]), + ], categories=['a', 'b', 'c']) + + expected = Dataset.from_iterable([ + DatasetItem(id=100, subset='train', image=np.ones((10, 6, 3)), + annotations=[ + Bbox(1, 2, 3, 4, label=2, id=1, group=1, + attributes={'score': 0.5, 'occluded': False, + 'difficult': False, 'truncated': False}), + Bbox(5, 6, 2, 3, label=3, id=2, group=2, + attributes={'score': 0.5, 'occluded': False, + 'difficult': False, 'truncated': False}), + Bbox(1, 2, 3, 3, label=1, id=1, group=1, + attributes={'score': 0.5, 'is_crowd': False}), + ]), + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + ['background', 'a', 'b', 'c']), + AnnotationType.mask: MaskCategories(VOC.generate_colormap(4)) 
+ }) + + with TestDir() as test_dir: + dataset1_url = osp.join(test_dir, 'dataset1') + dataset2_url = osp.join(test_dir, 'dataset2') + + dataset1.export(dataset1_url, 'coco', save_images=True) + dataset2.export(dataset2_url, 'voc', save_images=True) + + proj_dir = osp.join(test_dir, 'proj') + with Project.init(proj_dir) as project: + project.import_source('source', dataset2_url, 'voc') + + result_dir = osp.join(test_dir, 'cmp_result') + run(self, 'merge', dataset1_url + ':coco', '-o', result_dir, + '-p', proj_dir) + + compare_datasets(self, expected, Dataset.load(result_dir), + require_images=True) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_run_multimerge(self): + dataset1 = Dataset.from_iterable([ + DatasetItem(id=100, subset='train', image=np.ones((10, 6, 3)), + annotations=[ + Bbox(1, 2, 3, 3, label=0), + ]), + ], categories=['a', 'b']) + + dataset2 = Dataset.from_iterable([ + DatasetItem(id=100, subset='train', image=np.ones((10, 6, 3)), + annotations=[ + Bbox(1, 2, 3, 4, label=1), + Bbox(5, 6, 2, 3, label=2), + ]), + ], categories=['a', 'b', 'c']) + + expected = Dataset.from_iterable([ + DatasetItem(id=100, subset='train', image=np.ones((10, 6, 3)), + annotations=[ + Bbox(1, 2, 3, 4, label=2, id=1, group=1, + attributes={'score': 0.5, 'occluded': False, + 'difficult': False, 'truncated': False}), + Bbox(5, 6, 2, 3, label=3, id=2, group=2, + attributes={'score': 0.5, 'occluded': False, + 'difficult': False, 'truncated': False}), + Bbox(1, 2, 3, 3, label=1, id=1, group=1, + attributes={'score': 0.5, 'is_crowd': False}), + ]), + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + ['background', 'a', 'b', 'c']), + AnnotationType.mask: MaskCategories(VOC.generate_colormap(4)) + }) + + with TestDir() as test_dir: + dataset1_url = osp.join(test_dir, 'dataset1') + dataset2_url = osp.join(test_dir, 'dataset2') + + dataset1.export(dataset1_url, 'coco', save_images=True) + dataset2.export(dataset2_url, 'voc', save_images=True) + + result_dir = osp.join(test_dir, 'cmp_result') + run(self, 'merge', dataset2_url + ':voc', dataset1_url + ':coco', + '-o', result_dir) + + compare_datasets(self, expected, Dataset.load(result_dir), + require_images=True) diff --git a/tests/cli/test_project.py b/tests/cli/test_project.py new file mode 100644 index 0000000000..357bedbecf --- /dev/null +++ b/tests/cli/test_project.py @@ -0,0 +1,167 @@ +from unittest import TestCase +import os.path as osp +import shutil + +import numpy as np + +from datumaro.components.annotation import Bbox, Label +from datumaro.components.dataset import DEFAULT_FORMAT, Dataset +from datumaro.components.extractor import DatasetItem +from datumaro.components.project import Project +from datumaro.util.test_utils import TestDir, compare_datasets +from datumaro.util.test_utils import run_datum as run + +from ..requirements import Requirements, mark_requirement + + +class ProjectIntegrationScenarios(TestCase): + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_convert_voc_as_coco(self): + voc_dir = osp.join(__file__[:__file__.rfind(osp.join('tests', ''))], + 'tests', 'assets', 'voc_dataset', 'voc_dataset1') + + with TestDir() as test_dir: + result_dir = osp.join(test_dir, 'coco_export') + + run(self, 'convert', + '-if', 'voc', '-i', voc_dir, + '-f', 'coco', '-o', result_dir, + '--', '--save-images', '--reindex', '1') + + self.assertTrue(osp.isdir(result_dir)) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_export_coco_as_voc(self): + # TODO: use subformats once 
importers are removed + coco_dir = osp.join(__file__[:__file__.rfind(osp.join('tests', ''))], + 'tests', 'assets', 'coco_dataset', 'coco_instances') + + with TestDir() as test_dir: + run(self, 'create', '-o', test_dir) + run(self, 'add', '-f', 'coco', '-p', test_dir, coco_dir) + + result_dir = osp.join(test_dir, 'voc_export') + run(self, 'export', '-f', 'voc', '-p', test_dir, '-o', result_dir, + '--', '--save-images') + + self.assertTrue(osp.isdir(result_dir)) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_list_project_info(self): + coco_dir = osp.join(__file__[:__file__.rfind(osp.join('tests', ''))], + 'tests', 'assets', 'coco_dataset', 'coco_instances') + + with TestDir() as test_dir: + run(self, 'create', '-o', test_dir) + run(self, 'add', '-f', 'coco', '-p', test_dir, coco_dir) + + with self.subTest("on project"): + run(self, 'project', 'info', '-p', test_dir) + + with self.subTest("on project revision"): + run(self, 'project', 'info', '-p', test_dir, 'HEAD') + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_list_dataset_info(self): + coco_dir = osp.join(__file__[:__file__.rfind(osp.join('tests', ''))], + 'tests', 'assets', 'coco_dataset', 'coco_instances') + + with TestDir() as test_dir: + run(self, 'create', '-o', test_dir) + run(self, 'add', '-f', 'coco', '-p', test_dir, coco_dir) + run(self, 'commit', '-m', 'first', '-p', test_dir) + + with self.subTest("on current project"): + run(self, 'info', '-p', test_dir) + + with self.subTest("on current project revision"): + run(self, 'info', '-p', test_dir, 'HEAD') + + with self.subTest("on other project"): + run(self, 'info', test_dir) + + with self.subTest("on other project revision"): + run(self, 'info', test_dir + '@HEAD') + + with self.subTest("on dataset"): + run(self, 'info', coco_dir + ':coco') + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_use_vcs(self): + with TestDir() as test_dir: + dataset_dir = osp.join(test_dir, 'dataset') + project_dir = osp.join(test_dir, 'proj') + result_dir = osp.join(project_dir, 'result') + + Dataset.from_iterable([ + DatasetItem(0, image=np.ones((1, 2, 3)), annotations=[ + Bbox(1, 1, 1, 1, label=0), + Bbox(2, 2, 2, 2, label=1), + ]) + ], categories=['a', 'b']).save(dataset_dir, save_images=True) + + run(self, 'create', '-o', project_dir) + run(self, 'add', '-p', project_dir, '-f', 'datumaro', dataset_dir) + run(self, 'commit', '-p', project_dir, '-m', 'Add data') + + run(self, 'transform', '-p', project_dir, + '-t', 'remap_labels', 'source-1', '--', '-l', 'b:cat') + run(self, 'commit', '-p', project_dir, '-m', 'Add transform') + + run(self, 'filter', '-p', project_dir, + '-e', '/item/annotation[label="cat"]', '-m', 'i+a') + run(self, 'commit', '-p', project_dir, '-m', 'Add filter') + + run(self, 'export', '-p', project_dir, '-f', 'coco', + '-o', result_dir, 'source-1', '--', '--save-images') + parsed = Dataset.import_from(result_dir, 'coco') + compare_datasets(self, Dataset.from_iterable([ + DatasetItem(0, image=np.ones((1, 2, 3)), + annotations=[ + Bbox(2, 2, 2, 2, label=1, + group=1, id=1, attributes={'is_crowd': False}), + ], attributes={ 'id': 1 }) + ], categories=['a', 'cat']), parsed, require_images=True) + + shutil.rmtree(result_dir, ignore_errors=True) + run(self, 'checkout', '-p', project_dir, 'HEAD~1') + run(self, 'export', '-p', project_dir, '-f', 'coco', + '-o', result_dir, '--', '--save-images') + parsed = Dataset.import_from(result_dir, 'coco') + compare_datasets(self, Dataset.from_iterable([ + DatasetItem(0, 
image=np.ones((1, 2, 3)), annotations=[ + Bbox(1, 1, 1, 1, label=0, + group=1, id=1, attributes={'is_crowd': False}), + Bbox(2, 2, 2, 2, label=1, + group=2, id=2, attributes={'is_crowd': False}), + ], attributes={ 'id': 1 }) + ], categories=['a', 'cat']), parsed, require_images=True) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_chain_transforms_in_working_tree(self): + with TestDir() as test_dir: + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project_dir = osp.join(test_dir, 'proj') + run(self, 'create', '-o', project_dir) + run(self, 'add', '-p', project_dir, + '--format', DEFAULT_FORMAT, source_url) + run(self, 'filter', '-p', project_dir, + '-e', '/item/annotation[label="b"]') + run(self, 'transform', '-p', project_dir, + '-t', 'rename', '--', '-e', '|2|qq|') + run(self, 'transform', '-p', project_dir, + '-t', 'remap_labels', '--', '-l', 'a:cat', '-l', 'b:dog') + + with Project(project_dir) as project: + built_dataset = project.working_tree.make_dataset() + + expected_dataset = Dataset.from_iterable([ + DatasetItem('qq', annotations=[Label(1)]), + ], categories=['cat', 'dog']) + compare_datasets(self, expected_dataset, built_dataset) diff --git a/tests/cli/test_revpath.py b/tests/cli/test_revpath.py new file mode 100644 index 0000000000..fd5541367d --- /dev/null +++ b/tests/cli/test_revpath.py @@ -0,0 +1,140 @@ +from unittest.case import TestCase +import os.path as osp + +from datumaro.cli.util.project import ( + WrongRevpathError, parse_full_revpath, split_local_revpath, +) +from datumaro.components.dataset import DEFAULT_FORMAT, Dataset, IDataset +from datumaro.components.errors import ( + MultipleFormatsMatchError, ProjectNotFoundError, UnknownTargetError, +) +from datumaro.components.extractor import DatasetItem +from datumaro.components.project import Project +from datumaro.util.scope import scope_add, scoped +from datumaro.util.test_utils import TestDir + +from ..requirements import Requirements, mark_requirement + + +class TestRevpath(TestCase): + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_parse(self): + test_dir = scope_add(TestDir()) + + dataset_url = osp.join(test_dir, 'source') + Dataset.from_iterable([DatasetItem(1)]).save(dataset_url) + + proj_dir = osp.join(test_dir, 'proj') + proj = scope_add(Project.init(proj_dir)) + proj.import_source('source-1', dataset_url, format=DEFAULT_FORMAT) + ref = proj.commit("second commit", allow_empty=True) + + with self.subTest("project"): + dataset, project = parse_full_revpath(proj_dir) + if project: + scope_add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertTrue(isinstance(project, Project)) + + with self.subTest("project ref"): + dataset, project = parse_full_revpath(f"{proj_dir}@{ref}") + if project: + scope_add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertTrue(isinstance(project, Project)) + + with self.subTest("project ref source"): + dataset, project = parse_full_revpath(f"{proj_dir}@{ref}:source-1") + if project: + scope_add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertTrue(isinstance(project, Project)) + + with self.subTest("project ref source stage"): + dataset, project = parse_full_revpath( + f"{proj_dir}@{ref}:source-1.root") + if project: + scope_add(project) + self.assertTrue(isinstance(dataset, IDataset)) + 
self.assertTrue(isinstance(project, Project)) + + with self.subTest("ref"): + dataset, project = parse_full_revpath(ref, proj) + if project: + scope_add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertEqual(None, project) + + with self.subTest("ref source"): + dataset, project = parse_full_revpath(f"{ref}:source-1", proj) + if project: + scope_add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertEqual(None, project) + + with self.subTest("ref source stage"): + dataset, project = parse_full_revpath(f"{ref}:source-1.root", proj) + if project: + scope_add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertEqual(None, project) + + with self.subTest("source"): + dataset, project = parse_full_revpath("source-1", proj) + if project: + scope_add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertEqual(None, project) + + with self.subTest("source stage"): + dataset, project = parse_full_revpath("source-1.root", proj) + if project: + scope_add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertEqual(None, project) + + with self.subTest("dataset (in context)"): + with self.assertRaises(WrongRevpathError) as cm: + parse_full_revpath(dataset_url, proj) + self.assertEqual( + {UnknownTargetError, MultipleFormatsMatchError}, + set(type(e) for e in cm.exception.problems) + ) + + with self.subTest("dataset format (in context)"): + dataset, project = parse_full_revpath( + f"{dataset_url}:datumaro", proj) + if project: + scope_add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertEqual(None, project) + + with self.subTest("dataset (no context)"): + with self.assertRaises(WrongRevpathError) as cm: + parse_full_revpath(dataset_url) + self.assertEqual( + {ProjectNotFoundError, MultipleFormatsMatchError}, + set(type(e) for e in cm.exception.problems) + ) + + with self.subTest("dataset format (no context)"): + dataset, project = parse_full_revpath(f"{dataset_url}:datumaro") + if project: + scope_add(project) + self.assertTrue(isinstance(dataset, IDataset)) + self.assertEqual(None, project) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_split_local_revpath(self): + with self.subTest("full"): + self.assertEqual(("rev", "tgt"), split_local_revpath("rev:tgt")) + + with self.subTest("rev only"): + self.assertEqual(("rev", ""), split_local_revpath("rev:")) + + with self.subTest("build target only"): + self.assertEqual(("", "tgt"), split_local_revpath("tgt")) + + with self.subTest("build target only (empty rev)"): + self.assertEqual(("", "tgt"), split_local_revpath(":tgt")) diff --git a/tests/cli/test_voc_format.py b/tests/cli/test_voc_format.py index 725564f0f7..a4dd011024 100644 --- a/tests/cli/test_voc_format.py +++ b/tests/cli/test_voc_format.py @@ -19,8 +19,12 @@ class VocIntegrationScenarios(TestCase): def _test_can_save_and_load(self, project_path, source_path, expected_dataset, dataset_format, result_path='', label_map=None): run(self, 'create', '-o', project_path) - run(self, 'add', 'path', '-p', project_path, '-f', dataset_format, - source_path) + + extra_args = [] + if result_path: + extra_args += ['-r', result_path] + run(self, 'add', '-p', project_path, '-f', dataset_format, + *extra_args, source_path) result_dir = osp.join(project_path, 'result') extra_args = ['--', '--save-images'] @@ -83,19 +87,17 @@ def test_preparing_dataset_for_train_model(self): with TestDir() as test_dir: run(self, 'create', '-o', test_dir) - run(self, 'add', 'path', '-p', test_dir, '-f', 
'voc', dataset_path) + run(self, 'add', '-p', test_dir, '-f', 'voc', dataset_path) - result_path = osp.join(test_dir, 'result') run(self, 'filter', '-p', test_dir, '-m', 'i+a', - '-e', "/item/annotation[occluded='False']", '-o', result_path) + '-e', "/item/annotation[occluded='False']") - split_path = osp.join(test_dir, 'split') - run(self, 'transform', '-p', result_path, '-o', split_path, + run(self, 'transform', '-p', test_dir, '-t', 'random_split', '--', '-s', 'test:.5', '-s', 'train:.5', '--seed', '1') export_path = osp.join(test_dir, 'dataset') - run(self, 'export', '-p', split_path, '-f', 'voc', + run(self, 'export', '-p', test_dir, '-f', 'voc', '-o', export_path, '--', '--label-map', 'voc') parsed_dataset = Dataset.import_from(export_path, format='voc') @@ -135,7 +137,7 @@ def test_export_to_voc_format(self): 'tests', 'assets', 'yolo_dataset') run(self, 'create', '-o', test_dir) - run(self, 'add', 'path', '-p', test_dir, '-f', 'yolo', yolo_dir) + run(self, 'add', '-p', test_dir, '-f', 'yolo', yolo_dir) voc_export = osp.join(test_dir, 'voc_export') run(self, 'export', '-p', test_dir, '-f', 'voc', @@ -271,7 +273,7 @@ def test_can_save_and_load_voc_dataset(self): @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_can_save_and_load_voc_layout_dataset(self): - source_dataset = Dataset.from_iterable([ + expected_dataset = Dataset.from_iterable([ DatasetItem(id='2007_000001', subset='train', image=np.ones((10, 20, 3)), annotations=[ @@ -303,18 +305,17 @@ def test_can_save_and_load_voc_layout_dataset(self): for format, subset, path in matrix: with self.subTest(format=format, subset=subset, path=path): if subset: - source = source_dataset.get_subset(subset) + expected = expected_dataset.get_subset(subset) else: - source = source_dataset + expected = expected_dataset with TestDir() as test_dir: - self._test_can_save_and_load(test_dir, - osp.join(dataset_dir, path), source, - format, result_path=path, label_map='voc') + self._test_can_save_and_load(test_dir, dataset_dir, + expected, format, result_path=path, label_map='voc') @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_can_save_and_load_voc_classification_dataset(self): - source_dataset = Dataset.from_iterable([ + expected_dataset = Dataset.from_iterable([ DatasetItem(id='2007_000001', subset='train', image=np.ones((10, 20, 3)), annotations=[Label(i) for i in range(22) if i % 2 == 1]), @@ -332,18 +333,17 @@ def test_can_save_and_load_voc_classification_dataset(self): for format, subset, path in matrix: with self.subTest(format=format, subset=subset, path=path): if subset: - source = source_dataset.get_subset(subset) + expected = expected_dataset.get_subset(subset) else: - source = source_dataset + expected = expected_dataset with TestDir() as test_dir: - self._test_can_save_and_load(test_dir, - osp.join(dataset_dir, path), source, - format, result_path=path, label_map='voc') + self._test_can_save_and_load(test_dir, dataset_dir, + expected, format, result_path=path, label_map='voc') @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_can_save_and_load_voc_detection_dataset(self): - source_dataset = Dataset.from_iterable([ + expected_dataset = Dataset.from_iterable([ DatasetItem(id='2007_000001', subset='train', image=np.ones((10, 20, 3)), annotations=[ @@ -381,18 +381,17 @@ def test_can_save_and_load_voc_detection_dataset(self): for format, subset, path in matrix: with self.subTest(format=format, subset=subset, path=path): if subset: - source = source_dataset.get_subset(subset) + expected = 
expected_dataset.get_subset(subset) else: - source = source_dataset + expected = expected_dataset with TestDir() as test_dir: - self._test_can_save_and_load(test_dir, - osp.join(dataset_dir, path), source, - format, result_path=path, label_map='voc') + self._test_can_save_and_load(test_dir, dataset_dir, + expected, format, result_path=path, label_map='voc') @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_can_save_and_load_voc_segmentation_dataset(self): - source_dataset = Dataset.from_iterable([ + expected_dataset = Dataset.from_iterable([ DatasetItem(id='2007_000001', subset='train', image=np.ones((10, 20, 3)), annotations=[ @@ -413,14 +412,13 @@ def test_can_save_and_load_voc_segmentation_dataset(self): for format, subset, path in matrix: with self.subTest(format=format, subset=subset, path=path): if subset: - source = source_dataset.get_subset(subset) + expected = expected_dataset.get_subset(subset) else: - source = source_dataset + expected = expected_dataset with TestDir() as test_dir: - self._test_can_save_and_load(test_dir, - osp.join(dataset_dir, path), source, - format, result_path=path, label_map='voc') + self._test_can_save_and_load(test_dir, dataset_dir, + expected, format, result_path=path, label_map='voc') @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_can_save_and_load_voc_action_dataset(self): @@ -460,6 +458,5 @@ def test_can_save_and_load_voc_action_dataset(self): expected = expected_dataset with TestDir() as test_dir: - self._test_can_save_and_load(test_dir, - osp.join(dataset_dir, path), expected, - format, result_path=path, label_map='voc') + self._test_can_save_and_load(test_dir, dataset_dir, + expected, format, result_path=path, label_map='voc') diff --git a/tests/cli/test_yolo_format.py b/tests/cli/test_yolo_format.py index ebf850fc91..c2b00c49aa 100644 --- a/tests/cli/test_yolo_format.py +++ b/tests/cli/test_yolo_format.py @@ -30,7 +30,8 @@ def test_can_save_and_load_yolo_dataset(self): yolo_dir = osp.join(__file__[:__file__.rfind(osp.join('tests', ''))], 'tests', 'assets', 'yolo_dataset') - run(self, 'import', '-o', test_dir, '-f', 'yolo', '-i', yolo_dir) + run(self, 'create', '-o', test_dir) + run(self, 'add', '-p', test_dir, '-f', 'yolo', yolo_dir) export_dir = osp.join(test_dir, 'export_dir') run(self, 'export', '-p', test_dir, '-o', export_dir, @@ -54,7 +55,7 @@ def test_can_export_mot_as_yolo(self): 'tests', 'assets', 'mot_dataset') run(self, 'create', '-o', test_dir) - run(self, 'add', 'path', '-p', test_dir, '-f', 'mot_seq', mot_dir) + run(self, 'add', '-p', test_dir, '-f', 'mot_seq', mot_dir) yolo_dir = osp.join(test_dir, 'yolo_dir') run(self, 'export', '-p', test_dir, '-o', yolo_dir, @@ -119,11 +120,12 @@ def test_can_ignore_non_supported_subsets(self): dataset_dir = osp.join(test_dir, 'dataset_dir') source_dataset.save(dataset_dir, save_images=True) - run(self, 'create', '-o', test_dir) - run(self, 'add', 'path', '-p', test_dir, '-f', 'datumaro', dataset_dir) + proj_dir = osp.join(test_dir, 'proj') + run(self, 'create', '-o', proj_dir) + run(self, 'add', '-p', proj_dir, '-f', 'datumaro', dataset_dir) yolo_dir = osp.join(test_dir, 'yolo_dir') - run(self, 'export', '-p', test_dir, '-o', yolo_dir, + run(self, 'export', '-p', proj_dir, '-o', yolo_dir, '-f', 'yolo', '--', '--save-images') parsed_dataset = Dataset.import_from(yolo_dir, format='yolo') @@ -145,19 +147,17 @@ def test_can_delete_labels_from_yolo_dataset(self): 'tests', 'assets', 'yolo_dataset') run(self, 'create', '-o', test_dir) - run(self, 'add', 'path', '-p', test_dir, 
'-f', 'yolo', yolo_dir) + run(self, 'add', '-p', test_dir, '-f', 'yolo', yolo_dir) - filtered_path = osp.join(test_dir, 'filtered') - run(self, 'filter', '-p', test_dir, '-o', filtered_path, + run(self, 'filter', '-p', test_dir, '-m', 'i+a', '-e', "/item/annotation[label='label_2']") - result_path = osp.join(test_dir, 'result') - run(self, 'transform', '-p', filtered_path, '-o', result_path, + run(self, 'transform', '-p', test_dir, '-t', 'remap_labels', '--', '-l', 'label_2:label_2', '--default', 'delete') export_dir = osp.join(test_dir, 'export') - run(self, 'export', '-p', result_path, '-o', export_dir, + run(self, 'export', '-p', test_dir, '-o', export_dir, '-f', 'yolo', '--', '--save-image') parsed_dataset = Dataset.import_from(export_dir, format='yolo') diff --git a/tests/test_command_targets.py b/tests/test_command_targets.py deleted file mode 100644 index 81c7c2c374..0000000000 --- a/tests/test_command_targets.py +++ /dev/null @@ -1,143 +0,0 @@ -from unittest import TestCase -import os.path as osp - -import numpy as np - -from datumaro.components.project import Project -from datumaro.util.command_targets import ( - ImageTarget, ProjectTarget, SourceTarget, -) -from datumaro.util.image import save_image -from datumaro.util.test_utils import TestDir - -from .requirements import Requirements, mark_requirement - - -class CommandTargetsTest(TestCase): - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_image_false_when_no_file(self): - target = ImageTarget() - - status = target.test('somepath.jpg') - - self.assertFalse(status) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_image_false_when_false(self): - with TestDir() as test_dir: - path = osp.join(test_dir, 'test.jpg') - with open(path, 'w+') as f: - f.write('qwerty123') - - target = ImageTarget() - - status = target.test(path) - - self.assertFalse(status) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_image_true_when_true(self): - with TestDir() as test_dir: - path = osp.join(test_dir, 'test.jpg') - save_image(path, np.ones([10, 7, 3])) - - target = ImageTarget() - - status = target.test(path) - - self.assertTrue(status) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_project_false_when_no_file(self): - target = ProjectTarget() - - status = target.test('somepath.jpg') - - self.assertFalse(status) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_project_false_when_no_name(self): - target = ProjectTarget(project=Project()) - - status = target.test('') - - self.assertFalse(status) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_project_true_when_project_file(self): - with TestDir() as test_dir: - path = osp.join(test_dir, 'test.jpg') - Project().save(path) - - target = ProjectTarget() - - status = target.test(path) - - self.assertTrue(status) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_project_true_when_project_name(self): - project_name = 'qwerty' - project = Project({ - 'project_name': project_name - }) - target = ProjectTarget(project=project) - - status = target.test(project_name) - - self.assertTrue(status) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_project_false_when_not_project_name(self): - project_name = 'qwerty' - project = Project({ - 'project_name': project_name - }) - target = ProjectTarget(project=project) - - status = target.test(project_name + '123') - - self.assertFalse(status) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_project_false_when_not_project_file(self): 
- with TestDir() as test_dir: - path = osp.join(test_dir, 'test.jpg') - with open(path, 'w+') as f: - f.write('wqererw') - - target = ProjectTarget() - - status = target.test(path) - - self.assertFalse(status) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_source_false_when_no_project(self): - target = SourceTarget() - - status = target.test('qwerty123') - - self.assertFalse(status) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_source_true_when_source_exists(self): - source_name = 'qwerty' - project = Project() - project.add_source(source_name) - target = SourceTarget(project=project) - - status = target.test(source_name) - - self.assertTrue(status) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_source_false_when_source_doesnt_exist(self): - source_name = 'qwerty' - project = Project() - project.add_source(source_name) - target = SourceTarget(project=project) - - status = target.test(source_name + '123') - - self.assertFalse(status) diff --git a/tests/test_project.py b/tests/test_project.py index 59222c8975..57ba6e166c 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -1,437 +1,876 @@ from unittest import TestCase import os import os.path as osp +import shutil import textwrap import numpy as np -from datumaro.components.annotation import ( - AnnotationType, Label, LabelCategories, -) -from datumaro.components.config import Config +from datumaro.components.annotation import Bbox, Label from datumaro.components.config_model import Model, Source from datumaro.components.dataset import DEFAULT_FORMAT, Dataset -from datumaro.components.extractor import DatasetItem, Extractor -from datumaro.components.launcher import Launcher, ModelTransform -from datumaro.components.project import Environment, Project -from datumaro.util.test_utils import TestDir, compare_datasets +from datumaro.components.errors import ( + DatasetMergeError, EmptyCommitError, ForeignChangesError, + MismatchingObjectError, MissingObjectError, OldProjectError, + PathOutsideSourceError, ReadonlyProjectError, SourceExistsError, + SourceUrlInsideProjectError, +) +from datumaro.components.extractor import DatasetItem, Extractor, ItemTransform +from datumaro.components.launcher import Launcher +from datumaro.components.project import DiffStatus, Project +from datumaro.util.scope import scope_add, scoped +from datumaro.util.test_utils import TestDir, compare_datasets, compare_dirs from .requirements import Requirements, mark_requirement class ProjectTest(TestCase): @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_project_generate(self): - src_config = Config({ - 'project_name': 'test_project', - 'format_version': 1, - }) + @scoped + def test_can_init_and_load(self): + test_dir = scope_add(TestDir()) - with TestDir() as test_dir: - project_path = test_dir - Project.generate(project_path, src_config) + scope_add(Project.init(test_dir)).close() + scope_add(Project(test_dir)) - self.assertTrue(osp.isdir(project_path)) + self.assertTrue('.datumaro' in os.listdir(test_dir)) - result_config = Project.load(project_path).config - self.assertEqual( - src_config.project_name, result_config.project_name) - self.assertEqual( - src_config.format_version, result_config.format_version) - - @staticmethod @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_default_ctor_is_ok(): - Project() + @scoped + def test_can_find_project_in_project_dir(self): + test_dir = scope_add(TestDir()) + + scope_add(Project.init(test_dir)) + + self.assertEqual(osp.join(test_dir, 
'.datumaro'), + Project.find_project_dir(test_dir)) - @staticmethod @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_empty_config_is_ok(): - Project(Config()) + @scoped + def test_cant_find_project_when_no_project(self): + test_dir = scope_add(TestDir()) + + self.assertEqual(None, Project.find_project_dir(test_dir)) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_add_source(self): + @scoped + def test_can_add_local_model(self): + class TestLauncher(Launcher): + pass + source_name = 'source' - origin = Source({ - 'url': 'path', - 'format': 'ext' + config = Model({ + 'launcher': 'test', + 'options': { 'a': 5, 'b': 'hello' } }) - project = Project() - project.add_source(source_name, origin) + test_dir = scope_add(TestDir()) + project = scope_add(Project.init(test_dir)) + project.env.launchers.register('test', TestLauncher) - added = project.get_source(source_name) - self.assertIsNotNone(added) - self.assertEqual(added, origin) + project.add_model(source_name, + launcher=config.launcher, options=config.options) + + added = project.models[source_name] + self.assertEqual(added.launcher, config.launcher) + self.assertEqual(added.options, config.options) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_added_source_can_be_saved(self): - source_name = 'source' - origin = Source({ - 'url': 'path', - }) - project = Project() - project.add_source(source_name, origin) + @scoped + def test_can_run_inference(self): + class TestLauncher(Launcher): + def launch(self, inputs): + for inp in inputs: + yield [ Label(inp[0, 0, 0]) ] + + expected = Dataset.from_iterable([ + DatasetItem(0, image=np.zeros([2, 2, 3]), annotations=[Label(0)]), + DatasetItem(1, image=np.ones([2, 2, 3]), annotations=[Label(1)]) + ], categories=['a', 'b']) + + launcher_name = 'custom_launcher' + model_name = 'model' + + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones([2, 2, 3]) * 0), + DatasetItem(1, image=np.ones([2, 2, 3]) * 1), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.env.launchers.register(launcher_name, TestLauncher) + project.add_model(model_name, launcher=launcher_name) + project.import_source('source', source_url, format=DEFAULT_FORMAT) - saved = project.config + dataset = project.working_tree.make_dataset() + model = project.make_model(model_name) - self.assertEqual(origin, saved.sources[source_name]) + inference = dataset.run_model(model) + + compare_datasets(self, expected, inference) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_added_source_can_be_dumped(self): + @scoped + def test_can_import_local_source(self): + test_dir = scope_add(TestDir()) + source_base_url = osp.join(test_dir, 'test_repo') + source_file_path = osp.join(source_base_url, 'x', 'y.txt') + os.makedirs(osp.dirname(source_file_path), exist_ok=True) + with open(source_file_path, 'w') as f: + f.write('hello') + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_base_url, format='fmt') + + source = project.working_tree.sources['s1'] + self.assertEqual('fmt', source.format) + compare_dirs(self, source_base_url, project.source_data_dir('s1')) + with open(osp.join(test_dir, 'proj', '.gitignore')) as f: + self.assertTrue('/s1' in [line.strip() for line in f]) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def 
test_can_import_local_source_with_relpath(self): + # This form must copy all the data in URL, but read only + # specified files. Required to support subtasks and subsets. + + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, subset='a', image=np.ones((2, 3, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='b', image=np.zeros((10, 20, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) + + expected_dataset = Dataset.from_iterable([ + DatasetItem(1, subset='b', image=np.zeros((10, 20, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT, + rpath=osp.join('annotations', 'b.json')) + + source = project.working_tree.sources['s1'] + self.assertEqual(DEFAULT_FORMAT, source.format) + + compare_dirs(self, source_url, project.source_data_dir('s1')) + read_dataset = project.working_tree.make_dataset('s1') + compare_datasets(self, expected_dataset, read_dataset, + require_images=True) + + with open(osp.join(test_dir, 'proj', '.gitignore')) as f: + self.assertTrue('/s1' in [line.strip() for line in f]) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_cant_import_local_source_with_relpath_outside(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + os.makedirs(source_url) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + + with self.assertRaises(PathOutsideSourceError): + project.import_source('s1', url=source_url, + format=DEFAULT_FORMAT, rpath='..') + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_cant_import_local_source_with_url_inside_project(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'qq') + with open(source_url, 'w') as f: + f.write('hello') + + project = scope_add(Project.init(test_dir)) + + with self.assertRaises(SourceUrlInsideProjectError): + project.import_source('s1', url=source_url, + format=DEFAULT_FORMAT) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_report_incompatible_sources(self): + test_dir = scope_add(TestDir()) + source1_url = osp.join(test_dir, 'dataset1') + dataset1 = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + ], categories=['a', 'b']) + dataset1.save(source1_url) + + source2_url = osp.join(test_dir, 'dataset2') + dataset2 = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + ], categories=['c', 'd']) + dataset2.save(source2_url) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source1_url, format=DEFAULT_FORMAT) + project.import_source('s2', url=source2_url, format=DEFAULT_FORMAT) + + with self.assertRaises(DatasetMergeError) as cm: + project.working_tree.make_dataset() + + self.assertEqual({'s1.root', 's2.root'}, cm.exception.sources) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_cant_add_sources_with_same_names(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + 
project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + + with self.assertRaises(SourceExistsError): + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_import_generated_source(self): + test_dir = scope_add(TestDir()) source_name = 'source' origin = Source({ - 'url': 'path', + # no url + 'format': 'fmt', + 'options': { 'c': 5, 'd': 'hello' } }) - project = Project() - project.add_source(source_name, origin) + project = scope_add(Project.init(test_dir)) - with TestDir() as test_dir: - project.save(test_dir) + project.import_source(source_name, url='', + format=origin.format, options=origin.options) - loaded = Project.load(test_dir) - loaded = loaded.get_source(source_name) - self.assertEqual(origin, loaded) + added = project.working_tree.sources[source_name] + self.assertEqual(added.format, origin.format) + self.assertEqual(added.options, origin.options) + with open(osp.join(test_dir, '.gitignore')) as f: + self.assertTrue('/' + source_name in [line.strip() for line in f]) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_can_import_with_custom_importer(self): - class TestImporter: - def __call__(self, path, subset=None): - return Project({ - 'project_filename': path, - 'subsets': [ subset ] - }) + @scoped + def test_cant_import_source_with_wrong_name(self): + test_dir = scope_add(TestDir()) + project = scope_add(Project.init(test_dir)) - path = 'path' - importer_name = 'test_importer' + for name in {'dataset', 'project', 'build', '.any'}: + with self.subTest(name=name), \ + self.assertRaisesRegex(ValueError, "Source name"): + project.import_source(name, url='', format='fmt') + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_remove_source_and_keep_data(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_source.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') - env = Environment() - env.importers.register(importer_name, TestImporter) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - project = Project.import_from(path, importer_name, env, - subset='train') + project.remove_source('s1', keep_data=True) - self.assertEqual(path, project.config.project_filename) - self.assertListEqual(['train'], project.config.subsets) + self.assertFalse('s1' in project.working_tree.sources) + compare_dirs(self, source_url, project.source_data_dir('s1')) + with open(osp.join(test_dir, 'proj', '.gitignore')) as f: + self.assertFalse('/s1' in [line.strip() for line in f]) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_can_dump_added_model(self): - model_name = 'model' + @scoped + def test_can_remove_source_and_wipe_data(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_source.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') - project = Project() - saved = Model({ 'launcher': 'name' }) - project.add_model(model_name, saved) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - with TestDir() as test_dir: - project.save(test_dir) + project.remove_source('s1', keep_data=False) - loaded = Project.load(test_dir) - loaded = loaded.get_model(model_name) - self.assertEqual(saved, loaded) + self.assertFalse('s1' in 
project.working_tree.sources) + self.assertFalse(osp.exists(project.source_data_dir('s1'))) + with open(osp.join(test_dir, 'proj', '.gitignore')) as f: + self.assertFalse('/s1' in [line.strip() for line in f]) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_can_have_project_source(self): - with TestDir() as test_dir: - Project.generate(test_dir) + @scoped + def test_can_redownload_source_rev_noncached(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones((2, 3, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='s', image=np.zeros((10, 20, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.commit("A commit") - project2 = Project() - project2.add_source('project1', { - 'url': test_dir, - }) - dataset = project2.make_dataset() + # remove local source data + project.remove_cache_obj( + project.working_tree.build_targets['s1'].head.hash) + shutil.rmtree(project.source_data_dir('s1')) - self.assertTrue('project1' in dataset.sources) + read_dataset = project.working_tree.make_dataset('s1') + + compare_datasets(self, source_dataset, read_dataset) + compare_dirs(self, source_url, project.cache_path( + project.working_tree.build_targets['s1'].root.hash)) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_can_batch_launch_custom_model(self): - dataset = Dataset.from_iterable([ - DatasetItem(id=i, subset='train', image=np.array([i])) - for i in range(5) - ], categories=['label']) + @scoped + def test_can_redownload_source_and_check_data_hash(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones((2, 3, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='s', image=np.zeros((10, 20, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) - class TestLauncher(Launcher): - def launch(self, inputs): - for i, inp in enumerate(inputs): - yield [ Label(0, attributes={'idx': i, 'data': inp.item()}) ] + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.commit("A commit") - model_name = 'model' - launcher_name = 'custom_launcher' + # remove local source data + project.remove_cache_obj( + project.working_tree.build_targets['s1'].head.hash) + shutil.rmtree(project.source_data_dir('s1')) - project = Project() - project.env.launchers.register(launcher_name, TestLauncher) - project.add_model(model_name, { 'launcher': launcher_name }) - model = project.make_executable_model(model_name) + # modify the source repo + with open(osp.join(source_url, 'extra_file.txt'), 'w') as f: + f.write('text\n') + + with self.assertRaises(MismatchingObjectError): + project.working_tree.make_dataset('s1') - batch_size = 3 - executor = ModelTransform(dataset, model, batch_size=batch_size) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_use_source_from_cache_with_working_copy(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones((2, 3, 3)), + 
annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='s', image=np.zeros((10, 20, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) - for item in executor: - self.assertEqual(1, len(item.annotations)) - self.assertEqual(int(item.id) % batch_size, - item.annotations[0].attributes['idx']) - self.assertEqual(int(item.id), - item.annotations[0].attributes['data']) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.commit("A commit") + + shutil.rmtree(project.source_data_dir('s1')) + + read_dataset = project.working_tree.make_dataset('s1') + + compare_datasets(self, source_dataset, read_dataset) + self.assertFalse(osp.isdir(project.source_data_dir('s1'))) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_can_do_transform_with_custom_model(self): - class TestExtractorSrc(Extractor): - def __iter__(self): - for i in range(2): - yield DatasetItem(id=i, image=np.ones([2, 2, 3]) * i, - annotations=[Label(i)]) + @scoped + def test_raises_an_error_if_local_data_unknown(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones((2, 3, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='s', image=np.zeros((10, 20, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) - def categories(self): - label_cat = LabelCategories() - label_cat.add('0') - label_cat.add('1') - return { AnnotationType.label: label_cat } + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.commit("A commit") - class TestLauncher(Launcher): - def launch(self, inputs): - for inp in inputs: - yield [ Label(inp[0, 0, 0]) ] + # remove the cached object so that it couldn't be matched + project.remove_cache_obj( + project.working_tree.build_targets['s1'].root.hash) - class TestExtractorDst(Extractor): - def __init__(self, url): - super().__init__() - self.items = [osp.join(url, p) for p in sorted(os.listdir(url))] + # modify local source data + with open(osp.join(project.source_data_dir('s1'), 'extra.txt'), + 'w') as f: + f.write('text\n') - def __iter__(self): - for path in self.items: - with open(path, 'r') as f: - index = osp.splitext(osp.basename(path))[0] - label = int(f.readline().strip()) - yield DatasetItem(id=index, annotations=[Label(label)]) + with self.assertRaises(ForeignChangesError): + project.working_tree.make_dataset('s1') - model_name = 'model' - launcher_name = 'custom_launcher' - extractor_name = 'custom_extractor' + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_read_working_copy_of_source(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones((2, 3, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='s', image=np.ones((1, 2, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) - project = Project() - project.env.launchers.register(launcher_name, TestLauncher) - project.env.extractors.register(extractor_name, TestExtractorSrc) - project.add_model(model_name, { 'launcher': launcher_name }) - 
project.add_source('source', { 'format': extractor_name }) - - with TestDir() as test_dir: - project.make_dataset().apply_model(model=model_name, - save_dir=test_dir) - - result = Project.load(test_dir) - result.env.extractors.register(extractor_name, TestExtractorDst) - it = iter(result.make_dataset()) - item1 = next(it) - item2 = next(it) - self.assertEqual(0, item1.annotations[0].label) - self.assertEqual(1, item2.annotations[0].label) - - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_source_datasets_can_be_merged(self): - class TestExtractor(Extractor): - def __init__(self, url, n=0, s=0): - super().__init__(length=n) - self.n = n - self.s = s + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - def __iter__(self): - for i in range(self.n): - yield DatasetItem(id=self.s + i, subset='train') + read_dataset = project.working_tree.make_dataset('s1') + + compare_datasets(self, source_dataset, read_dataset) + compare_dirs(self, source_url, project.source_data_dir('s1')) - e_name1 = 'e1' - e_name2 = 'e2' - n1 = 2 - n2 = 4 + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_read_current_revision_of_source(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'source') + source_dataset = Dataset.from_iterable([ + DatasetItem(0, image=np.ones((2, 3, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=0) ]), + DatasetItem(1, subset='s', image=np.ones((1, 2, 3)), + annotations=[ Bbox(1, 2, 3, 4, label=1) ]), + ], categories=['a', 'b']) + source_dataset.save(source_url, save_images=True) - project = Project() - project.env.extractors.register(e_name1, lambda p: TestExtractor(p, n=n1)) - project.env.extractors.register(e_name2, lambda p: TestExtractor(p, n=n2, s=n1)) - project.add_source('source1', { 'format': e_name1 }) - project.add_source('source2', { 'format': e_name2 }) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.commit("A commit") - dataset = project.make_dataset() + shutil.rmtree(project.source_data_dir('s1')) - self.assertEqual(n1 + n2, len(dataset)) + read_dataset = project.head.make_dataset('s1') + + compare_datasets(self, source_dataset, read_dataset) + self.assertFalse(osp.isdir(project.source_data_dir('s1'))) + compare_dirs(self, source_url, project.head.source_data_dir('s1')) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_cant_merge_different_categories(self): - class TestExtractor1(Extractor): - def __iter__(self): - return iter([]) + @scoped + def test_can_make_dataset_from_project(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + source_dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + ], categories=['a', 'b']) + source_dataset.save(source_url) - def categories(self): - return { AnnotationType.label: - LabelCategories.from_iterable(['a', 'b']) } + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - class TestExtractor2(Extractor): - def __iter__(self): - return iter([]) + read_dataset = project.working_tree.make_dataset() - def categories(self): - return { AnnotationType.label: - LabelCategories.from_iterable(['b', 'a']) } + compare_datasets(self, source_dataset, read_dataset) - e_name1 = 'e1' - e_name2 = 'e2' + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def 
test_can_make_dataset_from_source(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - project = Project() - project.env.extractors.register(e_name1, TestExtractor1) - project.env.extractors.register(e_name2, TestExtractor2) - project.add_source('source1', { 'format': e_name1 }) - project.add_source('source2', { 'format': e_name2 }) + built_dataset = project.working_tree.make_dataset('s1') - with self.assertRaisesRegex(Exception, "different categories"): - project.make_dataset() + compare_datasets(self, dataset, built_dataset) + self.assertEqual(DEFAULT_FORMAT, built_dataset.format) + self.assertEqual(project.source_data_dir('s1'), + built_dataset.data_path) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_project_filter_can_be_applied(self): - class TestExtractor(Extractor): - def __iter__(self): - for i in range(10): - yield DatasetItem(id=i, subset='train') + @scoped + def test_can_add_filter_stage(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) - e_type = 'type' - project = Project() - project.env.extractors.register(e_type, TestExtractor) - project.add_source('source', { 'format': e_type }) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - dataset = project.make_dataset().filter('/item[id < 5]') + stage = project.working_tree.build_targets.add_filter_stage('s1', + '/item/annotation[label="b"]' + ) - self.assertEqual(5, len(dataset)) + self.assertTrue(stage in project.working_tree.build_targets) + resulting_dataset = project.working_tree.make_dataset('s1') + compare_datasets(self, Dataset.from_iterable([ + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']), resulting_dataset) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_can_save_and_load_own_dataset(self): - with TestDir() as test_dir: - src_project = Project() - src_dataset = src_project.make_dataset() - item = DatasetItem(id=1) - src_dataset.put(item) - src_dataset.save(test_dir) + @scoped + def test_can_add_convert_stage(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) - loaded_project = Project.load(test_dir) - loaded_dataset = loaded_project.make_dataset() + stage = project.working_tree.build_targets.add_convert_stage('s1', + DEFAULT_FORMAT) - self.assertEqual(list(src_dataset), list(loaded_dataset)) + self.assertTrue(stage in project.working_tree.build_targets) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_add_transform_stage(self): + class TestTransform(ItemTransform): + def __init__(self, extractor, p1=None, p2=None): + super().__init__(extractor) + self.p1 = p1 + self.p2 = p2 + + def transform_item(self, item): + 
return self.wrap_item(item, + attributes={'p1': self.p1, 'p2': self.p2}) + + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.working_tree.env.transforms.register('tr', TestTransform) + + stage = project.working_tree.build_targets.add_transform_stage('s1', + 'tr', params={'p1': 5, 'p2': ['1', 2, 3.5]} + ) + + self.assertTrue(stage in project.working_tree.build_targets) + resulting_dataset = project.working_tree.make_dataset('s1') + compare_datasets(self, Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)], + attributes={'p1': 5, 'p2': ['1', 2, 3.5]}), + DatasetItem(2, annotations=[Label(1)], + attributes={'p1': 5, 'p2': ['1', 2, 3.5]}), + ], categories=['a', 'b']), resulting_dataset) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_project_own_dataset_can_be_modified(self): - project = Project() - dataset = project.make_dataset() + @scoped + def test_can_make_dataset_from_stage(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + stage = project.working_tree.build_targets.add_filter_stage('s1', + '/item/annotation[label="b"]') - item = DatasetItem(id=1) - dataset.put(item) + built_dataset = project.working_tree.make_dataset(stage) - self.assertEqual(item, next(iter(dataset))) + expected_dataset = Dataset.from_iterable([ + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + compare_datasets(self, expected_dataset, built_dataset) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_project_compound_child_can_be_modified_recursively(self): - with TestDir() as test_dir: - child1 = Project({ - 'project_dir': osp.join(test_dir, 'child1'), - }) - child1.save() + @scoped + def test_can_commit(self): + test_dir = scope_add(TestDir()) + project = scope_add(Project.init(test_dir)) - child2 = Project({ - 'project_dir': osp.join(test_dir, 'child2'), - }) - child2.save() + commit_hash = project.commit("First commit", allow_empty=True) - parent = Project() - parent.add_source('child1', { - 'url': child1.config.project_dir - }) - parent.add_source('child2', { - 'url': child2.config.project_dir - }) - dataset = parent.make_dataset() + self.assertTrue(project.is_ref(commit_hash)) + self.assertEqual(len(project.history()), 2) + self.assertEqual(project.history()[0], + (commit_hash, "First commit")) - item1 = DatasetItem(id='ch1', path=['child1']) - item2 = DatasetItem(id='ch2', path=['child2']) - dataset.put(item1) - dataset.put(item2) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_cant_commit_empty(self): + test_dir = scope_add(TestDir()) + project = scope_add(Project.init(test_dir)) - self.assertEqual(2, len(dataset)) - self.assertEqual(1, len(dataset.sources['child1'])) - self.assertEqual(1, len(dataset.sources['child2'])) + with self.assertRaises(EmptyCommitError): + project.commit("First commit") @mark_requirement(Requirements.DATUM_GENERAL_REQ) 
- def test_project_can_merge_item_annotations(self): - class TestExtractor1(Extractor): - def __iter__(self): - yield DatasetItem(id=1, subset='train', annotations=[ - Label(2, id=3), - Label(3, attributes={ 'x': 1 }), - ]) + @scoped + def test_can_commit_patch(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_source.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', source_url, format=DEFAULT_FORMAT) + project.commit("First commit") + + source_path = osp.join( + project.source_data_dir('s1'), + osp.basename(source_url)) + with open(source_path, 'w') as f: + f.write('world') + + commit_hash = project.commit("Second commit", allow_foreign=True) + + self.assertTrue(project.is_ref(commit_hash)) + self.assertNotEqual( + project.get_rev('HEAD~1').build_targets['s1'].head.hash, + project.working_tree.build_targets['s1'].head.hash) + self.assertTrue(project.is_obj_cached( + project.working_tree.build_targets['s1'].head.hash)) - class TestExtractor2(Extractor): - def __iter__(self): - yield DatasetItem(id=1, subset='train', annotations=[ - Label(3, attributes={ 'x': 1 }), - Label(4, id=4), - ]) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_cant_commit_foreign_changes(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_source.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', source_url, format=DEFAULT_FORMAT) + project.commit("First commit") + + source_path = osp.join( + project.source_data_dir('s1'), + osp.basename(source_url)) + with open(source_path, 'w') as f: + f.write('world') + + with self.assertRaises(ForeignChangesError): + project.commit("Second commit") - project = Project() - project.env.extractors.register('t1', TestExtractor1) - project.env.extractors.register('t2', TestExtractor2) - project.add_source('source1', { 'format': 't1' }) - project.add_source('source2', { 'format': 't2' }) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_checkout_revision(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_source.txt') + os.makedirs(osp.dirname(source_url), exist_ok=True) + with open(source_url, 'w') as f: + f.write('hello') + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', source_url, format=DEFAULT_FORMAT) + project.commit("First commit") + + source_path = osp.join( + project.source_data_dir('s1'), + osp.basename(source_url)) + with open(source_path, 'w') as f: + f.write('world') + project.commit("Second commit", allow_foreign=True) + + project.checkout('HEAD~1') + + compare_dirs(self, source_url, project.source_data_dir('s1')) + with open(osp.join(test_dir, 'proj', '.gitignore')) as f: + self.assertTrue('/s1' in [line.strip() for line in f]) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_checkout_sources(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', 
url=source_url, format=DEFAULT_FORMAT) + project.import_source('s2', url=source_url, format=DEFAULT_FORMAT) + project.commit("Commit 1") + project.remove_source('s1', keep_data=False) # remove s1 from tree + shutil.rmtree(project.source_data_dir('s2')) # modify s2 "manually" + + project.checkout(sources=['s1', 's2']) + + compare_dirs(self, source_url, project.source_data_dir('s1')) + compare_dirs(self, source_url, project.source_data_dir('s2')) + with open(osp.join(test_dir, 'proj', '.gitignore')) as f: + lines = [line.strip() for line in f] + self.assertTrue('/s1' in lines) + self.assertTrue('/s2' in lines) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_checkout_sources_from_revision(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) - merged = project.make_dataset() + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.commit("Commit 1") + project.remove_source('s1', keep_data=False) + project.commit("Commit 2") - self.assertEqual(1, len(merged)) + project.checkout(rev='HEAD~1', sources=['s1']) - item = next(iter(merged)) - self.assertEqual(3, len(item.annotations)) + compare_dirs(self, source_url, project.source_data_dir('s1')) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_can_detect_and_import(self): - env = Environment() - env.importers.items = {DEFAULT_FORMAT: env.importers[DEFAULT_FORMAT]} - env.extractors.items = {DEFAULT_FORMAT: env.extractors[DEFAULT_FORMAT]} + @scoped + def test_can_check_status(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) + + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.import_source('s2', url=source_url, format=DEFAULT_FORMAT) + project.import_source('s3', url=source_url, format=DEFAULT_FORMAT) + project.import_source('s4', url=source_url, format=DEFAULT_FORMAT) + project.import_source('s5', url=source_url, format=DEFAULT_FORMAT) + project.commit("Commit 1") + + project.remove_source('s2') + project.import_source('s6', url=source_url, format=DEFAULT_FORMAT) + + shutil.rmtree(project.source_data_dir('s3')) + + project.working_tree.build_targets \ + .add_transform_stage('s4', 'reindex') + project.working_tree.make_dataset('s4').save() + project.refresh_source_hash('s4') + + s5_dir = osp.join(project.source_data_dir('s5')) + with open(osp.join(s5_dir, 'annotations', 't.txt'), 'w') as f: + f.write("hello") + + status = project.status() + self.assertEqual({ + 's2': DiffStatus.removed, + 's3': DiffStatus.missing, + 's4': DiffStatus.modified, + 's5': DiffStatus.foreign_modified, + 's6': DiffStatus.added, + }, status) - source_dataset = Dataset.from_iterable([ - DatasetItem(id=1, annotations=[ Label(2) ]), - ], categories=['a', 'b', 'c']) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_compare_revisions(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, 
annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) - with TestDir() as test_dir: - source_dataset.save(test_dir) + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + project.import_source('s2', url=source_url, format=DEFAULT_FORMAT) + rev1 = project.commit("Commit 1") - project = Project.import_from(test_dir, env=env) - imported_dataset = project.make_dataset() + project.remove_source('s2') + project.import_source('s3', url=source_url, format=DEFAULT_FORMAT) + rev2 = project.commit("Commit 2") - self.assertEqual(next(iter(project.config.sources.values())).format, - DEFAULT_FORMAT) - compare_datasets(self, source_dataset, imported_dataset) + diff = project.diff(rev1, rev2) + self.assertEqual(diff, + { 's2': DiffStatus.removed, 's3': DiffStatus.added }) @mark_requirement(Requirements.DATUM_GENERAL_REQ) - def test_custom_extractor_can_be_created(self): - class CustomExtractor(Extractor): - def __iter__(self): - return iter([ - DatasetItem(id=0, subset='train'), - DatasetItem(id=1, subset='train'), - DatasetItem(id=2, subset='train'), + @scoped + def test_can_restore_revision(self): + test_dir = scope_add(TestDir()) + source_url = osp.join(test_dir, 'test_repo') + dataset = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=['a', 'b']) + dataset.save(source_url) - DatasetItem(id=3, subset='test'), - DatasetItem(id=4, subset='test'), + project = scope_add(Project.init(osp.join(test_dir, 'proj'))) + project.import_source('s1', url=source_url, format=DEFAULT_FORMAT) + rev1 = project.commit("Commit 1") - DatasetItem(id=1), - DatasetItem(id=2), - DatasetItem(id=3), - ]) + project.remove_cache_obj(rev1) - extractor_name = 'ext1' - project = Project() - project.env.extractors.register(extractor_name, CustomExtractor) - project.add_source('src1', { - 'url': 'path', - 'format': extractor_name, - }) + self.assertFalse(project.is_rev_cached(rev1)) - dataset = project.make_dataset() + head_dataset = project.head.make_dataset() - compare_datasets(self, CustomExtractor(), dataset) + self.assertTrue(project.is_rev_cached(rev1)) + compare_datasets(self, dataset, head_dataset) + + @mark_requirement(Requirements.DATUM_BUG_404) + @scoped + def test_can_add_plugin(self): + test_dir = scope_add(TestDir()) + scope_add(Project.init(test_dir)).close() + + plugin_dir = osp.join(test_dir, '.datumaro', 'plugins') + os.makedirs(plugin_dir) + with open(osp.join(plugin_dir, '__init__.py'), 'w') as f: + f.write(textwrap.dedent(""" + from datumaro.components.extractor import (SourceExtractor, + DatasetItem) + + class MyExtractor(SourceExtractor): + def __iter__(self): + yield from [ + DatasetItem('1'), + DatasetItem('2'), + ] + """)) + + project = scope_add(Project(test_dir)) + project.import_source('src', url='', format='my') + + expected = Dataset.from_iterable([ + DatasetItem('1'), + DatasetItem('2') + ]) + compare_datasets(self, expected, project.working_tree.make_dataset()) @mark_requirement(Requirements.DATUM_BUG_402) + @scoped def test_can_transform_by_name(self): class CustomExtractor(Extractor): def __iter__(self): @@ -440,14 +879,12 @@ def __iter__(self): DatasetItem('b'), ]) + test_dir = scope_add(TestDir()) extractor_name = 'ext1' - project = Project() + project = scope_add(Project.init(test_dir)) project.env.extractors.register(extractor_name, CustomExtractor) - project.add_source('src1', { - 'url': '', - 'format': extractor_name, 
- }) - dataset = project.make_dataset() + project.import_source('src1', url='', format=extractor_name) + dataset = project.working_tree.make_dataset() dataset = dataset.transform('reindex') @@ -457,34 +894,87 @@ def __iter__(self): ]) compare_datasets(self, expected, dataset) - @mark_requirement(Requirements.DATUM_BUG_404) - def test_can_add_plugin(self): - with TestDir() as test_dir: - Project.generate(test_dir) - - plugin_dir = osp.join(test_dir, '.datumaro', 'plugins') - os.makedirs(plugin_dir) - with open(osp.join(plugin_dir, '__init__.py'), 'w') as f: - f.write(textwrap.dedent(""" - from datumaro.components.extractor import (SourceExtractor, - DatasetItem) - - class MyExtractor(SourceExtractor): - def __iter__(self): - yield from [ - DatasetItem('1'), - DatasetItem('2'), - ] - """)) - - project = Project.load(test_dir) - project.add_source('src', { - 'url': '', - 'format': 'my' - }) - - expected = Dataset.from_iterable([ - DatasetItem('1'), - DatasetItem('2') - ]) - compare_datasets(self, expected, project.make_dataset()) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_cant_modify_readonly(self): + test_dir = scope_add(TestDir()) + dataset_url = osp.join(test_dir, 'dataset') + Dataset.from_iterable([ + DatasetItem('a'), + DatasetItem('b'), + ]).save(dataset_url) + + proj_dir = osp.join(test_dir, 'proj') + with Project.init(proj_dir) as project: + project.import_source('source1', url=dataset_url, + format=DEFAULT_FORMAT) + project.commit('first commit') + project.remove_source('source1') + commit2 = project.commit('second commit') + project.checkout('HEAD~1') + project.remove_cache_obj(commit2) + project.remove_cache_obj( + project.working_tree.sources['source1'].hash) + + project = scope_add(Project(proj_dir, readonly=True)) + + self.assertTrue(project.readonly) + + with self.subTest("add source"), self.assertRaises(ReadonlyProjectError): + project.import_source('src1', url='', format=DEFAULT_FORMAT) + + with self.subTest("remove source"), self.assertRaises(ReadonlyProjectError): + project.remove_source('src1') + + with self.subTest("add model"), self.assertRaises(ReadonlyProjectError): + project.add_model('m1', launcher='x') + + with self.subTest("remove model"), self.assertRaises(ReadonlyProjectError): + project.remove_model('m1') + + with self.subTest("checkout"), self.assertRaises(ReadonlyProjectError): + project.checkout('HEAD') + + with self.subTest("commit"), self.assertRaises(ReadonlyProjectError): + project.commit('third commit', allow_empty=True) + + # Can't re-download the source in a readonly project + with self.subTest("make_dataset"), self.assertRaises(MissingObjectError): + project.get_rev('HEAD').make_dataset() + +class BackwardCompatibilityTests_v0_1(TestCase): + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_can_migrate_old_project(self): + expected_dataset = Dataset.from_iterable([ + DatasetItem(0, subset='train', annotations=[Label(0)]), + DatasetItem(1, subset='test', annotations=[Label(1)]), + DatasetItem(2, subset='train', annotations=[Label(0)]), + DatasetItem(1), + DatasetItem(2), + ], categories=['a', 'b']) + + test_dir = scope_add(TestDir()) + old_proj_dir = osp.join(test_dir, 'old_proj') + new_proj_dir = osp.join(test_dir, 'new_proj') + shutil.copytree(osp.join(osp.dirname(__file__), + 'assets', 'compat', 'v0.1', 'project'), + old_proj_dir) + + Project.migrate_from_v1_to_v2(old_proj_dir, new_proj_dir) + + project = scope_add(Project(new_proj_dir)) + loaded_dataset = project.working_tree.make_dataset() + 
compare_datasets(self, expected_dataset, loaded_dataset) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + @scoped + def test_cant_load_old_project(self): + test_dir = scope_add(TestDir()) + proj_dir = osp.join(test_dir, 'old_proj') + shutil.copytree(osp.join(osp.dirname(__file__), + 'assets', 'compat', 'v0.1', 'project'), + proj_dir) + + with self.assertRaises(OldProjectError): + scope_add(Project(proj_dir))
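
For orientation, below is a minimal usage sketch of the Project versioning workflow these tests exercise (source import, commit, dataset building, revision checkout). It is an illustration assembled only from the calls visible in the tests above; the 'datumaro' format name and the local paths are placeholder assumptions, not part of the patch.

from datumaro.components.project import Project

# Create a new project and register a data source (sketch; paths are placeholders)
project = Project.init('my_project')
project.import_source('src1', url='path/to/dataset', format='datumaro')

# Record the current working tree state
project.commit("Add src1")

# Build a dataset from the source as defined in the working tree
dataset = project.working_tree.make_dataset('src1')

# Change the tree, commit, then restore the previous revision
project.remove_source('src1', keep_data=False)
project.commit("Remove src1")
project.checkout('HEAD~1')
project.close()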