Skip to content

Commit

Permalink
Refactor _AnnotationMapper
Browse files Browse the repository at this point in the history
This has two main goals:

1. Move all validation that does not depend on the task we're annotating
   from `_AnnotationMapper` and to `DetectorFunctionSpec`. This a) splits
   the code into more manageable chunks, and b) works better with the agent
   workflow, because the CLI `function` commands can reject an invalid
   function spec immediately (rather than waiting until the first request).

   (Frankly, I don't know why I didn't put this code in
   `DetectorFunctionSpec` in the first place, given that it implements the
   restrictions described in the docstring...)

   Validating the spec upon construction requires that the spec is not
   modified afterwards. This has always been my intention, but to make it
   clearer, state it explicitly in the docstring.

2. Restructure the rest of the code into smaller functions. This should not
   introduce any differences in behavior.
  • Loading branch information
SpecLad committed Jan 30, 2025
1 parent e7ce0c4 commit 67e00a2
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 249 deletions.
3 changes: 2 additions & 1 deletion cvat-sdk/cvat_sdk/auto_annotation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
#
# SPDX-License-Identifier: MIT

from .driver import BadFunctionError, annotate_task
from .driver import annotate_task
from .exceptions import BadFunctionError
from .interface import (
DetectionFunction,
DetectionFunctionContext,
Expand Down
289 changes: 143 additions & 146 deletions cvat-sdk/cvat_sdk/auto_annotation/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,24 @@
#
# SPDX-License-Identifier: MIT

from __future__ import annotations

import logging
from collections.abc import Mapping, Sequence
from typing import Optional

import attrs
from typing_extensions import TypeAlias

import cvat_sdk.models as models
from cvat_sdk.core import Client
from cvat_sdk.core.progress import NullProgressReporter, ProgressReporter
from cvat_sdk.datasets.task_dataset import TaskDataset

from .exceptions import BadFunctionError
from .interface import DetectionFunction, DetectionFunctionContext, DetectionFunctionSpec


class BadFunctionError(Exception):
"""
An exception that signifies that an auto-detection function has violated some constraint
set by its interface.
"""


@attrs.frozen
class _SublabelNameMapping:
name: str
Expand Down Expand Up @@ -53,66 +50,35 @@ def map_label(self, name: str):


class _AnnotationMapper:
_SublabelIdMapping: TypeAlias = int

@attrs.frozen
class _LabelIdMapping:
id: int
sublabels: Mapping[int, Optional[int]]
sublabels: Mapping[int, Optional[_AnnotationMapper._SublabelIdMapping]]
expected_num_elements: int = 0

_label_id_mappings: Mapping[int, Optional[_LabelIdMapping]]
_SpecIdMapping: TypeAlias = Mapping[int, Optional[_LabelIdMapping]]

_spec_id_mapping: _SpecIdMapping

def _build_label_id_mapping(
self,
fun_label: models.ILabel,
ds_labels_by_name: Mapping[str, models.ILabel],
ds_label: models.ILabel,
*,
label_nm: _LabelNameMapping,
allow_unmatched_labels: bool,
spec_nm: _SpecNameMapping,
) -> Optional[_LabelIdMapping]:
if getattr(fun_label, "attributes", None):
raise BadFunctionError(f"label attributes are currently not supported")

label_nm = spec_nm.map_label(fun_label.name)
if label_nm is None:
return None

ds_label = ds_labels_by_name.get(label_nm.name)
if ds_label is None:
if not allow_unmatched_labels:
raise BadFunctionError(f"label {fun_label.name!r} is not in dataset")

self._logger.info(
"label %r is not in dataset; any annotations using it will be ignored",
fun_label.name,
)
return None

sl_map = {}

if getattr(fun_label, "sublabels", []):
fun_label_type = getattr(fun_label, "type", "any")
if fun_label_type != "skeleton":
raise BadFunctionError(
f"label {fun_label.name!r} with sublabels has type {fun_label_type!r} (should be 'skeleton')"
)

ds_sublabels_by_name = {ds_sl.name: ds_sl for ds_sl in ds_label.sublabels}

for fun_sl in fun_label.sublabels:
if not hasattr(fun_sl, "id"):
raise BadFunctionError(
f"sublabel {fun_sl.name!r} of label {fun_label.name!r} has no ID"
)

if fun_sl.id in sl_map:
raise BadFunctionError(
f"sublabel {fun_sl.name!r} of label {fun_label.name!r} has same ID as another sublabel ({fun_sl.id})"
)

def sublabel_mapping(fun_sl: models.ILabel) -> Optional[int]:
sublabel_nm = label_nm.map_sublabel(fun_sl.name)
if sublabel_nm is None:
sl_map[fun_sl.id] = None
continue
return None

ds_sl = ds_sublabels_by_name.get(sublabel_nm.name)
if not ds_sl:
Expand All @@ -126,151 +92,182 @@ def _build_label_id_mapping(
fun_sl.name,
fun_label.name,
)
sl_map[fun_sl.id] = None
continue
return None

return ds_sl.id

sl_map[fun_sl.id] = ds_sl.id
sl_map = {fun_sl.id: sublabel_mapping(fun_sl) for fun_sl in fun_label.sublabels}

return self._LabelIdMapping(
ds_label.id, sublabels=sl_map, expected_num_elements=len(ds_label.sublabels)
)

def __init__(
def _build_spec_id_mapping(
self,
logger: logging.Logger,
fun_labels: Sequence[models.ILabel],
ds_labels: Sequence[models.ILabel],
*,
spec_nm: _SpecNameMapping,
allow_unmatched_labels: bool,
conv_mask_to_poly: bool,
spec_nm: _SpecNameMapping = _SpecNameMapping(),
) -> None:
self._logger = logger
self._conv_mask_to_poly = conv_mask_to_poly

) -> _SpecIdMapping:
ds_labels_by_name = {ds_label.name: ds_label for ds_label in ds_labels}

self._label_id_mappings = {}
def label_id_mapping(fun_label: models.ILabel) -> Optional[self._LabelIdMapping]:
label_nm = spec_nm.map_label(fun_label.name)
if label_nm is None:
return None

for fun_label in fun_labels:
if not hasattr(fun_label, "id"):
raise BadFunctionError(f"label {fun_label.name!r} has no ID")
ds_label = ds_labels_by_name.get(label_nm.name)
if ds_label is None:
if not allow_unmatched_labels:
raise BadFunctionError(f"label {fun_label.name!r} is not in dataset")

if fun_label.id in self._label_id_mappings:
raise BadFunctionError(
f"label {fun_label.name} has same ID as another label ({fun_label.id})"
self._logger.info(
"label %r is not in dataset; any annotations using it will be ignored",
fun_label.name,
)
return None

self._label_id_mappings[fun_label.id] = self._build_label_id_mapping(
return self._build_label_id_mapping(
fun_label,
ds_labels_by_name,
ds_label,
label_nm=label_nm,
allow_unmatched_labels=allow_unmatched_labels,
spec_nm=spec_nm,
)

def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: int) -> None:
new_shapes = []
return {fun_label.id: label_id_mapping(fun_label) for fun_label in fun_labels}

def __init__(
self,
logger: logging.Logger,
fun_labels: Sequence[models.ILabel],
ds_labels: Sequence[models.ILabel],
*,
allow_unmatched_labels: bool,
conv_mask_to_poly: bool,
spec_nm: _SpecNameMapping = _SpecNameMapping(),
) -> None:
self._logger = logger
self._conv_mask_to_poly = conv_mask_to_poly

for shape in shapes:
if hasattr(shape, "id"):
raise BadFunctionError("function output shape with preset id")
self._spec_id_mapping = self._build_spec_id_mapping(
fun_labels, ds_labels, spec_nm=spec_nm, allow_unmatched_labels=allow_unmatched_labels
)

if hasattr(shape, "source"):
raise BadFunctionError("function output shape with preset source")
shape.source = "auto"
def _remap_element(
self,
element: models.SubLabeledShapeRequest,
ds_frame: int,
label_id_mapping: _LabelIdMapping,
seen_sl_ids: set[int],
) -> bool:
if hasattr(element, "id"):
raise BadFunctionError("function output shape element with preset id")

if hasattr(element, "source"):
raise BadFunctionError("function output shape element with preset source")
element.source = "auto"

if element.frame != 0:
raise BadFunctionError(
f"function output shape element with unexpected frame number ({element.frame})"
)

if shape.frame != 0:
raise BadFunctionError(
f"function output shape with unexpected frame number ({shape.frame})"
)
element.frame = ds_frame

shape.frame = ds_frame
if element.type.value != "points":
raise BadFunctionError(
f"function output skeleton with element type other than 'points' ({element.type.value})"
)

try:
label_id_mapping = self._label_id_mappings[shape.label_id]
except KeyError:
raise BadFunctionError(
f"function output shape with unknown label ID ({shape.label_id})"
)
try:
mapped_sl_id = label_id_mapping.sublabels[element.label_id]
except KeyError:
raise BadFunctionError(
f"function output shape with unknown sublabel ID ({element.label_id})"
)

if not label_id_mapping:
continue
if not mapped_sl_id:
return False

shape.label_id = label_id_mapping.id
if mapped_sl_id in seen_sl_ids:
raise BadFunctionError(
"function output skeleton with multiple elements with same sublabel"
)

if getattr(shape, "attributes", None):
raise BadFunctionError(
"function output shape with attributes, which is not yet supported"
)
element.label_id = mapped_sl_id

new_shapes.append(shape)
seen_sl_ids.add(mapped_sl_id)

if shape.type.value == "skeleton":
new_elements = []
seen_sl_ids = set()
return True

for element in shape.elements:
if hasattr(element, "id"):
raise BadFunctionError("function output shape element with preset id")
def _remap_elements(
self, shape: models.LabeledShapeRequest, ds_frame: int, label_id_mapping: _LabelIdMapping
) -> None:
if shape.type.value == "skeleton":
seen_sl_ids = set()

if hasattr(element, "source"):
raise BadFunctionError("function output shape element with preset source")
element.source = "auto"
shape.elements[:] = [
element
for element in shape.elements
if self._remap_element(element, ds_frame, label_id_mapping, seen_sl_ids)
]

if element.frame != 0:
raise BadFunctionError(
f"function output shape element with unexpected frame number ({element.frame})"
)
if len(shape.elements) != label_id_mapping.expected_num_elements:
# There could only be fewer elements than expected,
# because the reverse would imply that there are more distinct sublabel IDs
# than are actually defined in the dataset.
assert len(shape.elements) < label_id_mapping.expected_num_elements

element.frame = ds_frame
raise BadFunctionError(
"function output skeleton with fewer elements than expected"
f" ({len(shape.elements)} vs {label_id_mapping.expected_num_elements})"
)
else:
if getattr(shape, "elements", None):
raise BadFunctionError("function output non-skeleton shape with elements")

if element.type.value != "points":
raise BadFunctionError(
f"function output skeleton with element type other than 'points' ({element.type.value})"
)
def _remap_shape(self, shape: models.LabeledShapeRequest, ds_frame: int) -> bool:
if hasattr(shape, "id"):
raise BadFunctionError("function output shape with preset id")

try:
mapped_sl_id = label_id_mapping.sublabels[element.label_id]
except KeyError:
raise BadFunctionError(
f"function output shape with unknown sublabel ID ({element.label_id})"
)
if hasattr(shape, "source"):
raise BadFunctionError("function output shape with preset source")
shape.source = "auto"

if not mapped_sl_id:
continue
if shape.frame != 0:
raise BadFunctionError(
f"function output shape with unexpected frame number ({shape.frame})"
)

if mapped_sl_id in seen_sl_ids:
raise BadFunctionError(
"function output skeleton with multiple elements with same sublabel"
)
shape.frame = ds_frame

element.label_id = mapped_sl_id
try:
label_id_mapping = self._spec_id_mapping[shape.label_id]
except KeyError:
raise BadFunctionError(
f"function output shape with unknown label ID ({shape.label_id})"
)

seen_sl_ids.add(mapped_sl_id)
if not label_id_mapping:
return False

new_elements.append(element)
shape.label_id = label_id_mapping.id

if len(new_elements) != label_id_mapping.expected_num_elements:
# new_elements could only be shorter than expected,
# because the reverse would imply that there are more distinct sublabel IDs
# than are actually defined in the dataset.
assert len(new_elements) < label_id_mapping.expected_num_elements
if shape.type.value == "mask" and self._conv_mask_to_poly:
raise BadFunctionError("function output mask shape despite conv_mask_to_poly=True")

raise BadFunctionError(
f"function output skeleton with fewer elements than expected ({len(new_elements)} vs {label_id_mapping.expected_num_elements})"
)
if getattr(shape, "attributes", None):
raise BadFunctionError(
"function output shape with attributes, which is not yet supported"
)

shape.elements[:] = new_elements
else:
if getattr(shape, "elements", None):
raise BadFunctionError("function output non-skeleton shape with elements")
self._remap_elements(shape, ds_frame, label_id_mapping)

if shape.type.value == "mask" and self._conv_mask_to_poly:
raise BadFunctionError(
"function output mask shape despite conv_mask_to_poly=True"
)
return True

shapes[:] = new_shapes
def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: int) -> None:
shapes[:] = [shape for shape in shapes if self._remap_shape(shape, ds_frame)]


@attrs.frozen(kw_only=True)
Expand Down
Loading

0 comments on commit 67e00a2

Please sign in to comment.