diff --git a/cvat-sdk/cvat_sdk/auto_annotation/__init__.py b/cvat-sdk/cvat_sdk/auto_annotation/__init__.py index 8d4a22cbd661..03dccb9f805e 100644 --- a/cvat-sdk/cvat_sdk/auto_annotation/__init__.py +++ b/cvat-sdk/cvat_sdk/auto_annotation/__init__.py @@ -2,7 +2,8 @@ # # SPDX-License-Identifier: MIT -from .driver import BadFunctionError, annotate_task +from .driver import annotate_task +from .exceptions import BadFunctionError from .interface import ( DetectionFunction, DetectionFunctionContext, diff --git a/cvat-sdk/cvat_sdk/auto_annotation/driver.py b/cvat-sdk/cvat_sdk/auto_annotation/driver.py index 2b5627dcf3a9..05af519abdac 100644 --- a/cvat-sdk/cvat_sdk/auto_annotation/driver.py +++ b/cvat-sdk/cvat_sdk/auto_annotation/driver.py @@ -2,27 +2,24 @@ # # SPDX-License-Identifier: MIT +from __future__ import annotations + import logging from collections.abc import Mapping, Sequence from typing import Optional import attrs +from typing_extensions import TypeAlias import cvat_sdk.models as models from cvat_sdk.core import Client from cvat_sdk.core.progress import NullProgressReporter, ProgressReporter from cvat_sdk.datasets.task_dataset import TaskDataset +from .exceptions import BadFunctionError from .interface import DetectionFunction, DetectionFunctionContext, DetectionFunctionSpec -class BadFunctionError(Exception): - """ - An exception that signifies that an auto-detection function has violated some constraint - set by its interface. - """ - - @attrs.frozen class _SublabelNameMapping: name: str @@ -53,66 +50,35 @@ def map_label(self, name: str): class _AnnotationMapper: + _SublabelIdMapping: TypeAlias = int + @attrs.frozen class _LabelIdMapping: id: int - sublabels: Mapping[int, Optional[int]] + sublabels: Mapping[int, Optional[_AnnotationMapper._SublabelIdMapping]] expected_num_elements: int = 0 - _label_id_mappings: Mapping[int, Optional[_LabelIdMapping]] + _SpecIdMapping: TypeAlias = Mapping[int, Optional[_LabelIdMapping]] + + _spec_id_mapping: _SpecIdMapping def _build_label_id_mapping( self, fun_label: models.ILabel, - ds_labels_by_name: Mapping[str, models.ILabel], + ds_label: models.ILabel, *, + label_nm: _LabelNameMapping, allow_unmatched_labels: bool, - spec_nm: _SpecNameMapping, ) -> Optional[_LabelIdMapping]: - if getattr(fun_label, "attributes", None): - raise BadFunctionError(f"label attributes are currently not supported") - - label_nm = spec_nm.map_label(fun_label.name) - if label_nm is None: - return None - - ds_label = ds_labels_by_name.get(label_nm.name) - if ds_label is None: - if not allow_unmatched_labels: - raise BadFunctionError(f"label {fun_label.name!r} is not in dataset") - - self._logger.info( - "label %r is not in dataset; any annotations using it will be ignored", - fun_label.name, - ) - return None - sl_map = {} if getattr(fun_label, "sublabels", []): - fun_label_type = getattr(fun_label, "type", "any") - if fun_label_type != "skeleton": - raise BadFunctionError( - f"label {fun_label.name!r} with sublabels has type {fun_label_type!r} (should be 'skeleton')" - ) - ds_sublabels_by_name = {ds_sl.name: ds_sl for ds_sl in ds_label.sublabels} - for fun_sl in fun_label.sublabels: - if not hasattr(fun_sl, "id"): - raise BadFunctionError( - f"sublabel {fun_sl.name!r} of label {fun_label.name!r} has no ID" - ) - - if fun_sl.id in sl_map: - raise BadFunctionError( - f"sublabel {fun_sl.name!r} of label {fun_label.name!r} has same ID as another sublabel ({fun_sl.id})" - ) - + def sublabel_mapping(fun_sl: models.ILabel) -> Optional[int]: sublabel_nm = label_nm.map_sublabel(fun_sl.name) if sublabel_nm is None: - sl_map[fun_sl.id] = None - continue + return None ds_sl = ds_sublabels_by_name.get(sublabel_nm.name) if not ds_sl: @@ -126,151 +92,182 @@ def _build_label_id_mapping( fun_sl.name, fun_label.name, ) - sl_map[fun_sl.id] = None - continue + return None + + return ds_sl.id - sl_map[fun_sl.id] = ds_sl.id + sl_map = {fun_sl.id: sublabel_mapping(fun_sl) for fun_sl in fun_label.sublabels} return self._LabelIdMapping( ds_label.id, sublabels=sl_map, expected_num_elements=len(ds_label.sublabels) ) - def __init__( + def _build_spec_id_mapping( self, - logger: logging.Logger, fun_labels: Sequence[models.ILabel], ds_labels: Sequence[models.ILabel], *, + spec_nm: _SpecNameMapping, allow_unmatched_labels: bool, - conv_mask_to_poly: bool, - spec_nm: _SpecNameMapping = _SpecNameMapping(), - ) -> None: - self._logger = logger - self._conv_mask_to_poly = conv_mask_to_poly - + ) -> _SpecIdMapping: ds_labels_by_name = {ds_label.name: ds_label for ds_label in ds_labels} - self._label_id_mappings = {} + def label_id_mapping(fun_label: models.ILabel) -> Optional[self._LabelIdMapping]: + label_nm = spec_nm.map_label(fun_label.name) + if label_nm is None: + return None - for fun_label in fun_labels: - if not hasattr(fun_label, "id"): - raise BadFunctionError(f"label {fun_label.name!r} has no ID") + ds_label = ds_labels_by_name.get(label_nm.name) + if ds_label is None: + if not allow_unmatched_labels: + raise BadFunctionError(f"label {fun_label.name!r} is not in dataset") - if fun_label.id in self._label_id_mappings: - raise BadFunctionError( - f"label {fun_label.name} has same ID as another label ({fun_label.id})" + self._logger.info( + "label %r is not in dataset; any annotations using it will be ignored", + fun_label.name, ) + return None - self._label_id_mappings[fun_label.id] = self._build_label_id_mapping( + return self._build_label_id_mapping( fun_label, - ds_labels_by_name, + ds_label, + label_nm=label_nm, allow_unmatched_labels=allow_unmatched_labels, - spec_nm=spec_nm, ) - def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: int) -> None: - new_shapes = [] + return {fun_label.id: label_id_mapping(fun_label) for fun_label in fun_labels} + + def __init__( + self, + logger: logging.Logger, + fun_labels: Sequence[models.ILabel], + ds_labels: Sequence[models.ILabel], + *, + allow_unmatched_labels: bool, + conv_mask_to_poly: bool, + spec_nm: _SpecNameMapping = _SpecNameMapping(), + ) -> None: + self._logger = logger + self._conv_mask_to_poly = conv_mask_to_poly - for shape in shapes: - if hasattr(shape, "id"): - raise BadFunctionError("function output shape with preset id") + self._spec_id_mapping = self._build_spec_id_mapping( + fun_labels, ds_labels, spec_nm=spec_nm, allow_unmatched_labels=allow_unmatched_labels + ) - if hasattr(shape, "source"): - raise BadFunctionError("function output shape with preset source") - shape.source = "auto" + def _remap_element( + self, + element: models.SubLabeledShapeRequest, + ds_frame: int, + label_id_mapping: _LabelIdMapping, + seen_sl_ids: set[int], + ) -> bool: + if hasattr(element, "id"): + raise BadFunctionError("function output shape element with preset id") + + if hasattr(element, "source"): + raise BadFunctionError("function output shape element with preset source") + element.source = "auto" + + if element.frame != 0: + raise BadFunctionError( + f"function output shape element with unexpected frame number ({element.frame})" + ) - if shape.frame != 0: - raise BadFunctionError( - f"function output shape with unexpected frame number ({shape.frame})" - ) + element.frame = ds_frame - shape.frame = ds_frame + if element.type.value != "points": + raise BadFunctionError( + f"function output skeleton with element type other than 'points' ({element.type.value})" + ) - try: - label_id_mapping = self._label_id_mappings[shape.label_id] - except KeyError: - raise BadFunctionError( - f"function output shape with unknown label ID ({shape.label_id})" - ) + try: + mapped_sl_id = label_id_mapping.sublabels[element.label_id] + except KeyError: + raise BadFunctionError( + f"function output shape with unknown sublabel ID ({element.label_id})" + ) - if not label_id_mapping: - continue + if not mapped_sl_id: + return False - shape.label_id = label_id_mapping.id + if mapped_sl_id in seen_sl_ids: + raise BadFunctionError( + "function output skeleton with multiple elements with same sublabel" + ) - if getattr(shape, "attributes", None): - raise BadFunctionError( - "function output shape with attributes, which is not yet supported" - ) + element.label_id = mapped_sl_id - new_shapes.append(shape) + seen_sl_ids.add(mapped_sl_id) - if shape.type.value == "skeleton": - new_elements = [] - seen_sl_ids = set() + return True - for element in shape.elements: - if hasattr(element, "id"): - raise BadFunctionError("function output shape element with preset id") + def _remap_elements( + self, shape: models.LabeledShapeRequest, ds_frame: int, label_id_mapping: _LabelIdMapping + ) -> None: + if shape.type.value == "skeleton": + seen_sl_ids = set() - if hasattr(element, "source"): - raise BadFunctionError("function output shape element with preset source") - element.source = "auto" + shape.elements[:] = [ + element + for element in shape.elements + if self._remap_element(element, ds_frame, label_id_mapping, seen_sl_ids) + ] - if element.frame != 0: - raise BadFunctionError( - f"function output shape element with unexpected frame number ({element.frame})" - ) + if len(shape.elements) != label_id_mapping.expected_num_elements: + # There could only be fewer elements than expected, + # because the reverse would imply that there are more distinct sublabel IDs + # than are actually defined in the dataset. + assert len(shape.elements) < label_id_mapping.expected_num_elements - element.frame = ds_frame + raise BadFunctionError( + "function output skeleton with fewer elements than expected" + f" ({len(shape.elements)} vs {label_id_mapping.expected_num_elements})" + ) + else: + if getattr(shape, "elements", None): + raise BadFunctionError("function output non-skeleton shape with elements") - if element.type.value != "points": - raise BadFunctionError( - f"function output skeleton with element type other than 'points' ({element.type.value})" - ) + def _remap_shape(self, shape: models.LabeledShapeRequest, ds_frame: int) -> bool: + if hasattr(shape, "id"): + raise BadFunctionError("function output shape with preset id") - try: - mapped_sl_id = label_id_mapping.sublabels[element.label_id] - except KeyError: - raise BadFunctionError( - f"function output shape with unknown sublabel ID ({element.label_id})" - ) + if hasattr(shape, "source"): + raise BadFunctionError("function output shape with preset source") + shape.source = "auto" - if not mapped_sl_id: - continue + if shape.frame != 0: + raise BadFunctionError( + f"function output shape with unexpected frame number ({shape.frame})" + ) - if mapped_sl_id in seen_sl_ids: - raise BadFunctionError( - "function output skeleton with multiple elements with same sublabel" - ) + shape.frame = ds_frame - element.label_id = mapped_sl_id + try: + label_id_mapping = self._spec_id_mapping[shape.label_id] + except KeyError: + raise BadFunctionError( + f"function output shape with unknown label ID ({shape.label_id})" + ) - seen_sl_ids.add(mapped_sl_id) + if not label_id_mapping: + return False - new_elements.append(element) + shape.label_id = label_id_mapping.id - if len(new_elements) != label_id_mapping.expected_num_elements: - # new_elements could only be shorter than expected, - # because the reverse would imply that there are more distinct sublabel IDs - # than are actually defined in the dataset. - assert len(new_elements) < label_id_mapping.expected_num_elements + if shape.type.value == "mask" and self._conv_mask_to_poly: + raise BadFunctionError("function output mask shape despite conv_mask_to_poly=True") - raise BadFunctionError( - f"function output skeleton with fewer elements than expected ({len(new_elements)} vs {label_id_mapping.expected_num_elements})" - ) + if getattr(shape, "attributes", None): + raise BadFunctionError( + "function output shape with attributes, which is not yet supported" + ) - shape.elements[:] = new_elements - else: - if getattr(shape, "elements", None): - raise BadFunctionError("function output non-skeleton shape with elements") + self._remap_elements(shape, ds_frame, label_id_mapping) - if shape.type.value == "mask" and self._conv_mask_to_poly: - raise BadFunctionError( - "function output mask shape despite conv_mask_to_poly=True" - ) + return True - shapes[:] = new_shapes + def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: int) -> None: + shapes[:] = [shape for shape in shapes if self._remap_shape(shape, ds_frame)] @attrs.frozen(kw_only=True) diff --git a/cvat-sdk/cvat_sdk/auto_annotation/exceptions.py b/cvat-sdk/cvat_sdk/auto_annotation/exceptions.py new file mode 100644 index 000000000000..99717584b854 --- /dev/null +++ b/cvat-sdk/cvat_sdk/auto_annotation/exceptions.py @@ -0,0 +1,10 @@ +# Copyright (C) CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + + +class BadFunctionError(Exception): + """ + An exception that signifies that an auto-detection function has violated some constraint + set by its interface. + """ diff --git a/cvat-sdk/cvat_sdk/auto_annotation/interface.py b/cvat-sdk/cvat_sdk/auto_annotation/interface.py index 28275b83b346..0c9cf445114b 100644 --- a/cvat-sdk/cvat_sdk/auto_annotation/interface.py +++ b/cvat-sdk/cvat_sdk/auto_annotation/interface.py @@ -11,14 +11,19 @@ import cvat_sdk.models as models +from .exceptions import BadFunctionError + @attrs.frozen(kw_only=True) class DetectionFunctionSpec: """ Static information about an auto-annotation detection function. + + Objects of this class should be treated as immutable; + do not modify them or any nested objects after they are created. """ - labels: Sequence[models.PatchedLabelRequest] + labels: Sequence[models.PatchedLabelRequest] = attrs.field() """ Information about labels that the function supports. @@ -37,6 +42,49 @@ class DetectionFunctionSpec: constructors and help to follow some of the constraints. """ + @staticmethod + def _validate_label_spec(label: models.PatchedLabelRequest) -> None: + if getattr(label, "attributes", None): + raise BadFunctionError(f"label attributes are currently not supported") + + if getattr(label, "sublabels", []): + label_type = getattr(label, "type", "any") + if label_type != "skeleton": + raise BadFunctionError( + f"label {label.name!r} with sublabels has type {label_type!r} (should be 'skeleton')" + ) + + seen_sl_ids = set() + + for sl in label.sublabels: + if not hasattr(sl, "id"): + raise BadFunctionError( + f"sublabel {sl.name!r} of label {label.name!r} has no ID" + ) + + if sl.id in seen_sl_ids: + raise BadFunctionError( + f"sublabel {sl.name!r} of label {label.name!r} has same ID as another sublabel ({sl.id})" + ) + + seen_sl_ids.add(sl.id) + + @labels.validator + def _validate_labels(self, attribute, value: Sequence[models.PatchedLabelRequest]) -> None: + seen_label_ids = set() + + for label in value: + if not hasattr(label, "id"): + raise BadFunctionError(f"label {label.name!r} has no ID") + + if label.id in seen_label_ids: + raise BadFunctionError( + f"label {label.name} has same ID as another label ({label.id})" + ) + seen_label_ids.add(label.id) + + self._validate_label_spec(label) + class DetectionFunctionContext(metaclass=abc.ABCMeta): """ diff --git a/tests/python/sdk/test_auto_annotation.py b/tests/python/sdk/test_auto_annotation.py index 6b1646e03a24..6ad3057e8f53 100644 --- a/tests/python/sdk/test_auto_annotation.py +++ b/tests/python/sdk/test_auto_annotation.py @@ -42,6 +42,89 @@ def _common_setup( api_client.configuration.logger[k] = logger +class TestDetectionFunctionSpec: + def _test_bad_spec(self, exc_match: str, **kwargs) -> None: + with pytest.raises(cvataa.BadFunctionError, match=exc_match): + cvataa.DetectionFunctionSpec(**kwargs) + + def test_attributes(self): + self._test_bad_spec( + "currently not supported", + labels=[ + cvataa.label_spec( + "car", + 123, + attributes=[ + models.AttributeRequest( + "age", + mutable=False, + input_type="number", + values=["0", "100", "1"], + default_value="0", + ) + ], + ), + ], + ) + + def test_label_without_id(self): + self._test_bad_spec( + "label .+ has no ID", + labels=[ + models.PatchedLabelRequest( + name="car", + ), + ], + ) + + def test_duplicate_label_id(self): + self._test_bad_spec( + "same ID as another label", + labels=[ + cvataa.label_spec("car", 123), + cvataa.label_spec("bicycle", 123), + ], + ) + + def test_non_skeleton_sublabels(self): + self._test_bad_spec( + "should be 'skeleton'", + labels=[ + cvataa.label_spec( + "car", + 123, + sublabels=[models.SublabelRequest("wheel", id=1)], + ), + ], + ) + + def test_sublabel_without_id(self): + self._test_bad_spec( + "sublabel .+ of label .+ has no ID", + labels=[ + cvataa.skeleton_label_spec( + "car", + 123, + [models.SublabelRequest("wheel")], + ), + ], + ) + + def test_duplicate_sublabel_id(self): + self._test_bad_spec( + "same ID as another sublabel", + labels=[ + cvataa.skeleton_label_spec( + "cat", + 123, + [ + cvataa.keypoint_spec("head", 1), + cvataa.keypoint_spec("tail", 1), + ], + ), + ], + ) + class TestTaskAutoAnnotation: @pytest.fixture(autouse=True) def setup( @@ -342,119 +425,25 @@ def detect(context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]: assert received_cmtp is True - def _test_bad_function_spec(self, spec: cvataa.DetectionFunctionSpec, exc_match: str) -> None: + def _test_spec_dataset_mismatch(self, exc_match: str, spec: cvataa.DetectionFunctionSpec) -> None: def detect(context, image): assert False with pytest.raises(cvataa.BadFunctionError, match=exc_match): cvataa.annotate_task(self.client, self.task.id, namespace(spec=spec, detect=detect)) - def test_attributes(self): - self._test_bad_function_spec( - cvataa.DetectionFunctionSpec( - labels=[ - cvataa.label_spec( - "car", - 123, - attributes=[ - models.AttributeRequest( - "age", - mutable=False, - input_type="number", - values=["0", "100", "1"], - default_value="0", - ) - ], - ), - ], - ), - "currently not supported", - ) - def test_label_not_in_dataset(self): - self._test_bad_function_spec( - cvataa.DetectionFunctionSpec( - labels=[cvataa.label_spec("dog", 123)], - ), + self._test_spec_dataset_mismatch( "not in dataset", - ) - - def test_label_without_id(self): - self._test_bad_function_spec( - cvataa.DetectionFunctionSpec( - labels=[ - models.PatchedLabelRequest( - name="car", - ), - ], - ), - "label .+ has no ID", - ) - - def test_duplicate_label_id(self): - self._test_bad_function_spec( - cvataa.DetectionFunctionSpec( - labels=[ - cvataa.label_spec("car", 123), - cvataa.label_spec("bicycle", 123), - ], - ), - "same ID as another label", - ) - - def test_non_skeleton_sublabels(self): - self._test_bad_function_spec( - cvataa.DetectionFunctionSpec( - labels=[ - cvataa.label_spec( - "car", - 123, - sublabels=[models.SublabelRequest("wheel", id=1)], - ), - ], - ), - "should be 'skeleton'", - ) - - def test_sublabel_without_id(self): - self._test_bad_function_spec( - cvataa.DetectionFunctionSpec( - labels=[ - cvataa.skeleton_label_spec( - "car", - 123, - [models.SublabelRequest("wheel")], - ), - ], - ), - "sublabel .+ of label .+ has no ID", - ) - - def test_duplicate_sublabel_id(self): - self._test_bad_function_spec( - cvataa.DetectionFunctionSpec( - labels=[ - cvataa.skeleton_label_spec( - "cat", - 123, - [ - cvataa.keypoint_spec("head", 1), - cvataa.keypoint_spec("tail", 1), - ], - ), - ], - ), - "same ID as another sublabel", + cvataa.DetectionFunctionSpec(labels=[cvataa.label_spec("dog", 123)]), ) def test_sublabel_not_in_dataset(self): - self._test_bad_function_spec( - cvataa.DetectionFunctionSpec( - labels=[ - cvataa.skeleton_label_spec("cat", 123, [cvataa.keypoint_spec("nose", 1)]), - ], - ), + self._test_spec_dataset_mismatch( "not in dataset", + cvataa.DetectionFunctionSpec(labels=[ + cvataa.skeleton_label_spec("cat", 123, [cvataa.keypoint_spec("nose", 1)]), + ]), ) def _test_bad_function_detect(self, detect, exc_match: str) -> None: