Skip to content
This repository has been archived by the owner on Nov 14, 2023. It is now read-only.

Commit

Permalink
Fix skeleton tracks (cvat-ai#6075)
Browse files Browse the repository at this point in the history
Currently, we don't have validation of incoming annotations, as a
result, there is exist some cases when annotations successfully saved in
DB, but it's impossible to export them. In order to successfully export
a dataset with a skeleton track it's required that each track satisfy
the following condition:
` {frame number of track} == {frame number of parent track} == {frame
number of the first shape of the track}`

This PR adds an additional step during saving annotation in DB. This
additional step check that all these there "frame numbers" are equal and
try to automatically fix it's not true.
  • Loading branch information
Kirill Sizov authored and mikhail-treskin committed Jul 1, 2023
1 parent 1a1eac8 commit 63c193b
Show file tree
Hide file tree
Showing 9 changed files with 300 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/helm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,4 @@ jobs:
# They are still tested without Helm
run: |
kubectl cp tests/mounted_file_share/images $(kubectl get pods -l component=server -o jsonpath='{.items[0].metadata.name}'):/home/django/share
pytest --timeout 30 --platform=kube -m "not with_external_services" tests/python
pytest --timeout 30 --platform=kube -m "not with_external_services" tests/python --log-cli-level DEBUG
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ without use_cache option (<https://github.com/opencv/cvat/pull/6074>)

### Fixed
- Skeletons dumping on created tasks/projects (<https://github.com/opencv/cvat/pull/6157>)
- Fix saving annotations for skeleton tracks (<https://github.com/opencv/cvat/pull/6075>)

### Security
- TDB
Expand Down
162 changes: 113 additions & 49 deletions cvat/apps/dataset_manager/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@
#
# SPDX-License-Identifier: MIT

import os
from collections import OrderedDict
from copy import deepcopy
from enum import Enum
import os
from tempfile import TemporaryDirectory

from django.db import transaction
from django.db.models.query import Prefetch
from django.utils import timezone
from rest_framework.exceptions import ValidationError

from cvat.apps.engine import models, serializers
from cvat.apps.engine.plugins import plugin_decorator
from cvat.apps.profiler import silk_profile

from .annotation import AnnotationIR, AnnotationManager
from .bindings import TaskData, JobData
from .bindings import JobData, TaskData
from .formats.registry import make_exporter, make_importer
from .util import bulk_create

Expand Down Expand Up @@ -115,47 +117,108 @@ def __init__(self, pk, is_prefetched=False):
def reset(self):
self.ir_data.reset()

def _validate_attribute_for_existence(self, db_attr_val, label_id, attr_type):
if db_attr_val.spec_id not in self.db_attributes[label_id][attr_type]:
raise ValidationError("spec_id `{}` is invalid".format(db_attr_val.spec_id))

def _validate_label_for_existence(self, label_id):
if label_id not in self.db_labels:
raise ValidationError("label_id `{}` is invalid".format(label_id))

def _add_missing_shape(self, track, first_shape):
if first_shape["type"] == "skeleton":
# in case with skeleton track we always expect to see one shape in track
first_shape["frame"] = track["frame"]
else:
missing_shape = deepcopy(first_shape)
missing_shape["frame"] = track["frame"]
missing_shape["outside"] = True
track["shapes"].append(missing_shape)

def _correct_frame_of_tracked_shapes(self, track):
shapes = sorted(track["shapes"], key=lambda a: a["frame"])
first_shape = shapes[0] if shapes else None

if first_shape and track["frame"] < first_shape["frame"]:
self._add_missing_shape(track, first_shape)
elif first_shape and first_shape["frame"] < track["frame"]:
track["frame"] = first_shape["frame"]

def _sync_frames(self, tracks, parent_track):
if not tracks:
return

min_frame = tracks[0]["frame"]

for track in tracks:
if parent_track and parent_track.frame < track["frame"]:
track["frame"] = parent_track.frame

# track and its first shape must have the same frame
self._correct_frame_of_tracked_shapes(track)

if track["frame"] < min_frame:
min_frame = track["frame"]

if not parent_track:
return

if min_frame < parent_track.frame:
# parent track cannot have a frame greater than the frame of the child track
parent_tracked_shape = parent_track.trackedshape_set.first()
parent_track.frame = min_frame
parent_tracked_shape.frame = min_frame

parent_tracked_shape.save()
parent_track.save()

for track in tracks:
if parent_track.frame < track["frame"]:
track["frame"] = parent_track.frame

self._correct_frame_of_tracked_shapes(track)

def _save_tracks_to_db(self, tracks):

def create_tracks(tracks, parent_track=None):
db_tracks = []
db_track_attrvals = []
db_track_attr_vals = []
db_shapes = []
db_shape_attrvals = []
db_shape_attr_vals = []

self._sync_frames(tracks, parent_track)

for track in tracks:
track_attributes = track.pop("attributes", [])
shapes = track.pop("shapes")
elements = track.pop("elements", [])
db_track = models.LabeledTrack(job=self.db_job, parent=parent_track, **track)
if db_track.label_id not in self.db_labels:
raise AttributeError("label_id `{}` is invalid".format(db_track.label_id))

self._validate_label_for_existence(db_track.label_id)

for attr in track_attributes:
db_attrval = models.LabeledTrackAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes[db_track.label_id]["immutable"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.track_id = len(db_tracks)
db_track_attrvals.append(db_attrval)
db_attr_val = models.LabeledTrackAttributeVal(**attr, track_id=len(db_tracks))

self._validate_attribute_for_existence(db_attr_val, db_track.label_id, "immutable")

for shape in shapes:
db_track_attr_vals.append(db_attr_val)

for shape_idx, shape in enumerate(shapes):
shape_attributes = shape.pop("attributes", [])
# FIXME: need to clamp points (be sure that all of them inside the image)
# Should we check here or implement a validator?
db_shape = models.TrackedShape(**shape)
db_shape.track_id = len(db_tracks)
db_shape = models.TrackedShape(**shape, track_id=len(db_tracks))

for attr in shape_attributes:
db_attrval = models.TrackedShapeAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes[db_track.label_id]["mutable"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.shape_id = len(db_shapes)
db_shape_attrvals.append(db_attrval)
db_attr_val = models.TrackedShapeAttributeVal(**attr, shape_id=len(db_shapes))

self._validate_attribute_for_existence(db_attr_val, db_track.label_id, "mutable")

db_shape_attr_vals.append(db_attr_val)

db_shapes.append(db_shape)
shape["attributes"] = shape_attributes

db_tracks.append(db_track)

track["attributes"] = track_attributes
track["shapes"] = shapes
if elements or parent_track is None:
Expand All @@ -167,11 +230,12 @@ def create_tracks(tracks, parent_track=None):
flt_param={"job_id": self.db_job.id}
)

for db_attrval in db_track_attrvals:
db_attrval.track_id = db_tracks[db_attrval.track_id].id
for db_attr_val in db_track_attr_vals:
db_attr_val.track_id = db_tracks[db_attr_val.track_id].id

bulk_create(
db_model=models.LabeledTrackAttributeVal,
objects=db_track_attrvals,
objects=db_track_attr_vals,
flt_param={}
)

Expand All @@ -184,12 +248,12 @@ def create_tracks(tracks, parent_track=None):
flt_param={"track__job_id": self.db_job.id}
)

for db_attrval in db_shape_attrvals:
db_attrval.shape_id = db_shapes[db_attrval.shape_id].id
for db_attr_val in db_shape_attr_vals:
db_attr_val.shape_id = db_shapes[db_attr_val.shape_id].id

bulk_create(
db_model=models.TrackedShapeAttributeVal,
objects=db_shape_attrvals,
objects=db_shape_attr_vals,
flt_param={}
)

Expand All @@ -208,24 +272,23 @@ def create_tracks(tracks, parent_track=None):
def _save_shapes_to_db(self, shapes):
def create_shapes(shapes, parent_shape=None):
db_shapes = []
db_attrvals = []
db_attr_vals = []

for shape in shapes:
attributes = shape.pop("attributes", [])
shape_elements = shape.pop("elements", [])
# FIXME: need to clamp points (be sure that all of them inside the image)
# Should we check here or implement a validator?
db_shape = models.LabeledShape(job=self.db_job, parent=parent_shape, **shape)
if db_shape.label_id not in self.db_labels:
raise AttributeError("label_id `{}` is invalid".format(db_shape.label_id))

self._validate_label_for_existence(db_shape.label_id)

for attr in attributes:
db_attrval = models.LabeledShapeAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes[db_shape.label_id]["all"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attr_val = models.LabeledShapeAttributeVal(**attr, shape_id=len(db_shapes))

db_attrval.shape_id = len(db_shapes)
db_attrvals.append(db_attrval)
self._validate_attribute_for_existence(db_attr_val, db_shape.label_id, "all")

db_attr_vals.append(db_attr_val)

db_shapes.append(db_shape)
shape["attributes"] = attributes
Expand All @@ -238,12 +301,12 @@ def create_shapes(shapes, parent_shape=None):
flt_param={"job_id": self.db_job.id}
)

for db_attrval in db_attrvals:
db_attrval.shape_id = db_shapes[db_attrval.shape_id].id
for db_attr_val in db_attr_vals:
db_attr_val.shape_id = db_shapes[db_attr_val.shape_id].id

bulk_create(
db_model=models.LabeledShapeAttributeVal,
objects=db_attrvals,
objects=db_attr_vals,
flt_param={}
)

Expand All @@ -257,20 +320,21 @@ def create_shapes(shapes, parent_shape=None):

def _save_tags_to_db(self, tags):
db_tags = []
db_attrvals = []
db_attr_vals = []

for tag in tags:
attributes = tag.pop("attributes", [])
db_tag = models.LabeledImage(job=self.db_job, **tag)
if db_tag.label_id not in self.db_labels:
raise AttributeError("label_id `{}` is invalid".format(db_tag.label_id))

self._validate_label_for_existence(db_tag.label_id)

for attr in attributes:
db_attrval = models.LabeledImageAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes[db_tag.label_id]["all"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.tag_id = len(db_tags)
db_attrvals.append(db_attrval)
db_attr_val = models.LabeledImageAttributeVal(**attr)

self._validate_attribute_for_existence(db_attr_val, db_tag.label_id, "all")

db_attr_val.tag_id = len(db_tags)
db_attr_vals.append(db_attr_val)

db_tags.append(db_tag)
tag["attributes"] = attributes
Expand All @@ -281,12 +345,12 @@ def _save_tags_to_db(self, tags):
flt_param={"job_id": self.db_job.id}
)

for db_attrval in db_attrvals:
db_attrval.image_id = db_tags[db_attrval.tag_id].id
for db_attr_val in db_attr_vals:
db_attr_val.image_id = db_tags[db_attr_val.tag_id].id

bulk_create(
db_model=models.LabeledImageAttributeVal,
objects=db_attrvals,
objects=db_attr_vals,
flt_param={}
)

Expand Down
2 changes: 1 addition & 1 deletion cvat/apps/engine/tests/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4538,7 +4538,7 @@ def _run_api_v2_jobs_id_annotations(self, owner, assignee, annotator):
]
},
{
"frame": 1,
"frame": 2,
"label_id": task["labels"][1]["id"],
"group": None,
"source": "manual",
Expand Down
Loading

0 comments on commit 63c193b

Please sign in to comment.