Skip to content

Commit

Permalink
Add attribute support for auto-annotation functions (#9090)
Browse files Browse the repository at this point in the history
Remove one of the long-standing limitations on auto-annotation functions
by adding the necessary validation and remapping logic to support
attribute specifications and values. Add a utility module for attributes
with functionality I needed, but felt didn't belong in the
auto-annotation layer.

Adds the necessary code to support using functions with attributes via
agents, as well. I will submit the necesssary server-side code will be
submitted to the private repository later; until that is merged,
attempts to create native functions with attributes will be rejected.
  • Loading branch information
SpecLad authored Feb 19, 2025
1 parent 774b504 commit 79e6eff
Show file tree
Hide file tree
Showing 13 changed files with 1,006 additions and 99 deletions.
8 changes: 8 additions & 0 deletions changelog.d/20250211_145801_roman_aa_attributes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
### Added

- \[SDK\] Auto-annotation detection functions can now output shape/keypoint attributes
(<https://github.com/cvat-ai/cvat/pull/9090>)

- \{SDK\] Added a utility module for working with label attributes,
`cvat_sdk.attributes`
(<https://github.com/cvat-ai/cvat/pull/9090>)
45 changes: 30 additions & 15 deletions cvat-cli/src/cvat_cli/_internal/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from cvat_sdk.auto_annotation.driver import (
_AnnotationMapper,
_DetectionFunctionContextImpl,
_LabelNameMapping,
_SpecNameMapping,
)
from cvat_sdk.exceptions import ApiException
Expand Down Expand Up @@ -145,28 +144,42 @@ def _validate_detection_function_compatibility(self, remote_function: dict) -> N
labels_by_name = {label.name: label for label in self._function_spec.labels}

for remote_label in remote_function["labels_v2"]:
label_desc = f"label {remote_label['name']!r}"
label = labels_by_name.get(remote_label["name"])

if not label:
raise CriticalError(
incompatible_msg + f"label {remote_label['name']!r} is not supported."
)
raise CriticalError(incompatible_msg + f"{label_desc} is not supported.")

if (
remote_label["type"] not in {"any", "unknown"}
and remote_label["type"] != label.type
):
raise CriticalError(
incompatible_msg
+ f"label {remote_label['name']!r} has type {remote_label['type']!r}, "
f"but the function object expects type {label.type!r}."
incompatible_msg + f"{label_desc} has type {remote_label['type']!r}, "
f"but the function object declares type {label.type!r}."
)

if remote_label["attributes"]:
raise CriticalError(
incompatible_msg
+ f"label {remote_label['name']!r} has attributes, which is not supported."
)
attrs_by_name = {attr.name: attr for attr in getattr(label, "attributes", [])}

for remote_attr in remote_label["attributes"]:
attr_desc = f"attribute {remote_attr['name']!r} of {label_desc}"
attr = attrs_by_name.get(remote_attr["name"])

if not attr:
raise CriticalError(incompatible_msg + f"{attr_desc} is not supported.")

if remote_attr["input_type"] != attr.input_type.value:
raise CriticalError(
incompatible_msg
+ f"{attr_desc} has input type {remote_attr['input_type']!r},"
f" but the function object declares input type {attr.input_type.value!r}."
)

if remote_attr["values"] != attr.values:
raise CriticalError(
incompatible_msg + f"{attr_desc} has values {remote_attr['values']!r},"
f" but the function object declares values {attr.values!r}."
)

def _wait_between_polls(self):
# offset the interval randomly to avoid synchronization between workers
Expand Down Expand Up @@ -288,11 +301,13 @@ def _calculate_result_for_detection_ar(
self._update_ar(ar_id, 0)
last_update_timestamp = datetime.now(tz=timezone.utc)

mapping = ar_params["mapping"]
conv_mask_to_poly = ar_params["conv_mask_to_poly"]

spec_nm = _SpecNameMapping(
labels={k: _LabelNameMapping(v["name"]) for k, v in mapping.items()}
spec_nm = _SpecNameMapping.from_api(
{
k: models.LabelMappingEntryRequest._from_openapi_data(**v)
for k, v in ar_params["mapping"].items()
}
)

mapper = _AnnotationMapper(
Expand Down
8 changes: 8 additions & 0 deletions cvat-cli/src/cvat_cli/_internal/commands_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ def execute(
remote_function["labels_v2"].append(
{
"name": label_spec.name,
"attributes": [
{
"name": attribute_spec.name,
"input_type": attribute_spec.input_type,
"values": attribute_spec.values,
}
for attribute_spec in getattr(label_spec, "attributes", [])
],
}
)

Expand Down
2 changes: 1 addition & 1 deletion cvat-cli/src/cvat_cli/_internal/commands_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def configure_parser(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--allow-unmatched-labels",
action="store_true",
help="Allow the function to declare labels not configured in the task",
help="Allow the function to declare labels/sublabels/attributes not configured in the task",
)

parser.add_argument(
Expand Down
4 changes: 3 additions & 1 deletion cvat-sdk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ The SDK API includes several layers:
- Server API wrappers (`ApiClient`). Located in at `cvat_sdk.api_client`.
- High-level tools (`Core`). Located at `cvat_sdk.core`.
- PyTorch adapter. Located at `cvat_sdk.pytorch`.
* Auto-annotation support. Located at `cvat_sdk.auto_annotation`.
- Auto-annotation support. Located at `cvat_sdk.auto_annotation`.
- Miscellaneous utilities, grouped by topic.
Located at `cvat_sdk.attributes` and `cvat_sdk.masks`.

Package documentation is available [here](https://docs.cvat.ai/docs/api_sdk/sdk).

Expand Down
136 changes: 136 additions & 0 deletions cvat-sdk/cvat_sdk/attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from __future__ import annotations

from collections.abc import Mapping
from typing import Callable, Union

from . import models


class _CheckboxAttributeValueValidator:
def __init__(self, values: list[str]) -> None:
pass

def __call__(self, value: str) -> bool:
return value in {"true", "false"}


class _NumberAttributeValueValidator:
def __init__(self, values: list[str]) -> None:
if len(values) != 3:
raise ValueError(f"wrong number of values: expected 3, got {len(values)}")

try:
(self._min_value, self._max_value, self._step) = map(int, values)
except ValueError as ex:
raise ValueError(f"values could not be converted to integers") from ex

try:
number_attribute_values(self._min_value, self._max_value, self._step)
except ValueError as ex:
raise ValueError(f"invalid values: {ex}") from ex

def __call__(self, value: str) -> bool:
try:
value = int(value)
except ValueError:
return False

return (
self._min_value <= value <= self._max_value
and (value - self._min_value) % self._step == 0
)


class _SelectAttributeValueValidator:
def __init__(self, values: list[str]) -> None:
if len(values) == 0:
raise ValueError("empty list of allowed values")

self._values = frozenset(values)

def __call__(self, value: str) -> bool:
return value in self._values


class _TextAttributeValueValidator:
def __init__(self, values: list[str]) -> None:
pass

def __call__(self, value: str) -> bool:
return True


_VALIDATOR_CLASSES = {
"checkbox": _CheckboxAttributeValueValidator,
"number": _NumberAttributeValueValidator,
"radio": _SelectAttributeValueValidator,
"select": _SelectAttributeValueValidator,
"text": _TextAttributeValueValidator,
}

# make sure all possible types are covered
assert set(models.InputTypeEnum.allowed_values[("value",)].values()) == _VALIDATOR_CLASSES.keys()


def attribute_value_validator(spec: models.IAttributeRequest) -> Callable[[str], bool]:
"""
Returns a callable that can be used to verify
whether an attribute value is suitable for an attribute with the given spec.
The resulting callable takes a single argument (the attribute value as a string)
and returns True if and only if the value is suitable.
The spec's `values` attribute must be consistent with its `input_type` attribute,
otherwise ValueError will be raised.
"""
return _VALIDATOR_CLASSES[spec.input_type.value](spec.values)


def number_attribute_values(min_value: int, max_value: int, /, step: int = 1) -> list[str]:
"""
Returns a list suitable as the value of the "values" field of an `AttributeRequest`
with `input_type="number"`.
"""

if min_value > max_value:
raise ValueError("min_value must be less than or equal to max_value")

if step <= 0:
raise ValueError("step must be positive")

if (max_value - min_value) % step != 0:
raise ValueError("step must be a divisor of max_value - min_value")

return [str(min_value), str(max_value), str(step)]


def attribute_vals_from_dict(
id_to_value: Mapping[int, Union[str, int, bool]], /
) -> list[models.AttributeValRequest]:
"""
Returns a list of AttributeValRequest objects with given IDs and values.
The input value must be a mapping from attribute spec IDs to corresponding values.
A value may be specified as a string, an integer, or a boolean.
Integers and booleans will be converted to strings according to the format CVAT expects
for attributes with input type "number" and "checkbox", respectively.
"""

def val_as_string(v: Union[str, int, bool]) -> str:
if v is True:
return "true"
if v is False:
return "false"
if isinstance(v, int):
return str(v)
if isinstance(v, str):
return v
assert False, f"unexpected value {v!r} of type {type(v)}"

return [
models.AttributeValRequest(spec_id=k, value=val_as_string(v))
for k, v in id_to_value.items()
]
12 changes: 12 additions & 0 deletions cvat-sdk/cvat_sdk/auto_annotation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,42 @@
DetectionFunction,
DetectionFunctionContext,
DetectionFunctionSpec,
attribute_spec,
checkbox_attribute_spec,
keypoint,
keypoint_spec,
label_spec,
mask,
number_attribute_spec,
polygon,
radio_attribute_spec,
rectangle,
select_attribute_spec,
shape,
skeleton,
skeleton_label_spec,
text_attribute_spec,
)

__all__ = [
"annotate_task",
"attribute_spec",
"BadFunctionError",
"checkbox_attribute_spec",
"DetectionFunction",
"DetectionFunctionContext",
"DetectionFunctionSpec",
"keypoint_spec",
"keypoint",
"label_spec",
"mask",
"number_attribute_spec",
"polygon",
"radio_attribute_spec",
"rectangle",
"select_attribute_spec",
"shape",
"skeleton_label_spec",
"skeleton",
"text_attribute_spec",
]
Loading

0 comments on commit 79e6eff

Please sign in to comment.