Basic fix for merged dataset annotations
Fixes how merged datasets get annotations for all the task types. This will probably need further changes to use the original GT, since excluded GT annotations are currently missing from the result.
zhiltsov-max committed Nov 12, 2024
1 parent ba9228f commit 16c67b0
Showing 4 changed files with 100 additions and 14 deletions.
34 changes: 34 additions & 0 deletions packages/examples/cvat/recording-oracle/src/core/tasks/points.py
@@ -0,0 +1,34 @@
import os
from pathlib import Path
from tempfile import TemporaryDirectory

import datumaro as dm

# These details are relevant for image_points tasks


class TaskMetaLayout:
    GT_FILENAME = "gt.json"


class TaskMetaSerializer:
    GT_DATASET_FORMAT = "datumaro"

    def serialize_gt_annotations(self, gt_dataset: dm.Dataset) -> bytes:
        with TemporaryDirectory() as temp_dir:
            gt_dataset_dir = os.path.join(temp_dir, "gt_dataset")
            gt_dataset.export(gt_dataset_dir, self.GT_DATASET_FORMAT)
            return (Path(gt_dataset_dir) / "annotations" / "default.json").read_bytes()

    def parse_gt_annotations(self, gt_dataset_data: bytes) -> dm.Dataset:
        with TemporaryDirectory() as temp_dir:
            annotations_dir = os.path.join(temp_dir, "annotations")
            os.makedirs(annotations_dir)

            annotations_filename = os.path.join(annotations_dir, "default.json")
            with open(annotations_filename, "wb") as f:
                f.write(gt_dataset_data)

            dataset = dm.Dataset.import_from(temp_dir, format=self.GT_DATASET_FORMAT)
            dataset.init_cache()
            return dataset
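
As an aside, the two methods above are designed to round-trip: whatever serialize_gt_annotations produces, parse_gt_annotations restores. A minimal sketch of that contract — not part of the commit, with a made-up sample id and label name:

import datumaro as dm

from src.core.tasks.points import TaskMetaSerializer

serializer = TaskMetaSerializer()

# Hypothetical one-sample GT dataset with a single keypoint
gt_dataset = dm.Dataset.from_iterable(
    [dm.DatasetItem(id="sample_1", annotations=[dm.Points([10.0, 20.0], label=0)])],
    categories=["keypoint"],
)

gt_bytes = serializer.serialize_gt_annotations(gt_dataset)  # Datumaro JSON as bytes
restored = serializer.parse_gt_annotations(gt_bytes)        # back to a dm.Dataset
assert len(restored) == len(gt_dataset)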
@@ -4,7 +4,7 @@

import datumaro as dm

-# These details are relevant for image_points and image_boxes tasks
+# These details are relevant for image_boxes tasks


class TaskMetaLayout:
@@ -11,7 +11,10 @@
import datumaro as dm
import numpy as np

+import src.core.tasks.boxes_from_points as boxes_from_points_task
+import src.core.tasks.points as points_task
import src.core.tasks.simple as simple_task
+import src.core.tasks.skeletons_from_boxes as skeletons_from_boxes_task
import src.cvat.api_calls as cvat_api
import src.services.validation as db_service
from src.core.annotation_meta import AnnotationMeta
@@ -25,7 +28,7 @@
from src.db.utils import ForUpdateParams
from src.services.cloud import make_client as make_cloud_client
from src.services.cloud.utils import BucketAccessInfo
-from src.utils.annotations import ProjectLabels
+from src.utils.annotations import ProjectLabels, flatten_points
from src.utils.zip_archive import extract_zip_archive, write_dir_to_zip_archive

if TYPE_CHECKING:
@@ -115,12 +118,36 @@ def _require_field(self, field: T | None) -> T:
    def _gt_key_to_sample_id(self, gt_key: str) -> str:
        return gt_key

+    def _get_meta_layout_and_serializer(self):
+        if self.manifest.annotation.type == TaskTypes.image_boxes:
+            return (
+                simple_task.TaskMetaLayout(),
+                simple_task.TaskMetaSerializer(),
+            )
+        if self.manifest.annotation.type == TaskTypes.image_points:
+            return (
+                points_task.TaskMetaLayout(),
+                points_task.TaskMetaSerializer(),
+            )
+        if self.manifest.annotation.type == TaskTypes.image_boxes_from_points:
+            return (
+                boxes_from_points_task.TaskMetaLayout(),
+                boxes_from_points_task.TaskMetaSerializer(),
+            )
+        if self.manifest.annotation.type == TaskTypes.image_skeletons_from_boxes:
+            return (
+                skeletons_from_boxes_task.TaskMetaLayout(),
+                skeletons_from_boxes_task.TaskMetaSerializer(),
+            )
+        raise AssertionError(f"Unknown task type {self.manifest.annotation.type}")

    def _parse_gt(self):
-        layout = simple_task.TaskMetaLayout()
-        serializer = simple_task.TaskMetaSerializer()
+        layout, serializer = self._get_meta_layout_and_serializer()

-        oracle_data_bucket = BucketAccessInfo.parse_obj(Config.exchange_oracle_storage_config)
-        storage_client = make_cloud_client(oracle_data_bucket)
+        exchange_oracle_data_bucket = BucketAccessInfo.parse_obj(
+            Config.exchange_oracle_storage_config
+        )
+        storage_client = make_cloud_client(exchange_oracle_data_bucket)

        self._gt_dataset = serializer.parse_gt_annotations(
            storage_client.download_file(
@@ -292,18 +319,17 @@ def _put_gt_into_merged_dataset(
            annotations = [
                dm.Skeleton(
                    elements=[
-                       # Put a point in the center of each GT bbox
-                       # Not ideal, but it's the target for now
                        dm.Points(
-                           [bbox.x + bbox.w / 2, bbox.y + bbox.h / 2],
+                           point.points,
                            label=point_label_id,
-                           attributes=bbox.attributes,
+                           attributes=point.attributes,
                        )
                    ],
                    label=skeleton_label_id,
                )
-               for bbox in sample.annotations
-               if isinstance(bbox, dm.Bbox)
+               for point in flatten_points(
+                   [p for p in sample.annotations if isinstance(p, dm.Points)]
+               )
            ]
            merged_dataset.put(sample.wrap(annotations=annotations))
        case TaskTypes.image_label_binary.value:
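
To illustrate the fix above: instead of synthesizing a point at the center of each GT bbox, the merged dataset now gets one single-point skeleton per flattened GT point. A minimal sketch — not part of the commit, with made-up label ids:

import datumaro as dm

from src.utils.annotations import flatten_points

# Hypothetical GT sample: one dm.Points annotation carrying two keypoints
sample_annotations = [dm.Points([10, 20, 30, 40], label=0)]
skeleton_label_id, point_label_id = 0, 1  # made-up label ids

annotations = [
    dm.Skeleton(
        elements=[
            dm.Points(point.points, label=point_label_id, attributes=point.attributes)
        ],
        label=skeleton_label_id,
    )
    for point in flatten_points(
        [p for p in sample_annotations if isinstance(p, dm.Points)]
    )
]
assert len(annotations) == 2  # one single-point skeleton per GT keypoint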
30 changes: 28 additions & 2 deletions packages/examples/cvat/recording-oracle/src/utils/annotations.py
@@ -1,10 +1,36 @@
from argparse import ArgumentParser
-from collections.abc import Iterable
+from collections.abc import Iterable, Sequence
from copy import deepcopy

import datumaro as dm
import numpy as np
-from datumaro.util import mask_tools
+from datumaro.util import filter_dict, mask_tools


+def flatten_points(input_points: Sequence[dm.Points]) -> list[dm.Points]:
+    results = []
+
+    for pts in input_points:
+        for point_idx in range(len(pts.points) // 2):
+            point_x = pts.points[2 * point_idx + 0]
+            point_y = pts.points[2 * point_idx + 1]
+
+            point_v = pts.visibility[point_idx]
+            if pts.attributes.get("outside") is True:
+                point_v = dm.Points.Visibility.absent
+            elif point_v == dm.Points.Visibility.visible and pts.attributes.get("occluded") is True:
+                point_v = dm.Points.Visibility.hidden
+
+            results.append(
+                dm.Points(
+                    [point_x, point_y],
+                    visibility=[point_v],
+                    label=pts.label,
+                    attributes=filter_dict(pts.attributes, exclude_keys=["occluded", "outside"]),
+                )
+            )
+
+    return results


def shift_ann(
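
For reference, flatten_points splits each multi-point annotation into single-point ones and folds the CVAT-style "occluded"/"outside" attributes into Datumaro visibility flags. A minimal usage sketch — not part of the commit, with made-up input values:

import datumaro as dm

from src.utils.annotations import flatten_points

pts = dm.Points([10, 20, 30, 40], label=2, attributes={"occluded": True})
flat = flatten_points([pts])

assert len(flat) == 2
assert list(flat[0].points) == [10.0, 20.0]
assert list(flat[0].visibility) == [dm.Points.Visibility.hidden]  # occluded -> hidden
assert flat[0].attributes == {}  # "occluded"/"outside" are stripped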
