Skip to content

Commit

Permalink
Update dataset manager bindings with image info
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiltsov-max committed Feb 19, 2020
1 parent 482433c commit 1e640a2
Showing 1 changed file with 100 additions and 67 deletions.
167 changes: 100 additions & 67 deletions cvat/apps/dataset_manager/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from cvat.apps.engine.models import Task, ShapeType, AttributeType

import datumaro.components.extractor as datumaro
from datumaro.util.image import lazy_image
from datumaro.util.image import Image


class CvatImagesDirExtractor(datumaro.Extractor):
Expand All @@ -29,8 +29,7 @@ def __init__(self, url):
path = osp.join(dirpath, name)
if self._is_image(path):
item_id = Task.get_image_frame(path)
item = datumaro.DatasetItem(
id=item_id, image=lazy_image(path))
item = datumaro.DatasetItem(id=item_id, image=path)
items.append((item.id, item))

items = sorted(items, key=lambda e: int(e[0]))
Expand All @@ -49,112 +48,90 @@ def __len__(self):
def subsets(self):
return self._subsets

def get(self, item_id, subset=None, path=None):
if path or subset:
raise KeyError()
return self._items[item_id]

def _is_image(self, path):
for ext in self._SUPPORTED_FORMATS:
if osp.isfile(path) and path.endswith(ext):
return True
return False


class CvatTaskExtractor(datumaro.Extractor):
def __init__(self, url, db_task, user):
self._db_task = db_task
self._categories = self._load_categories()

cvat_annotations = TaskAnnotation(db_task.id, user)
with transaction.atomic():
cvat_annotations.init_from_db()
cvat_annotations = Annotation(cvat_annotations.ir_data, db_task)
class CvatAnnotationsExtractor(datumaro.Extractor):
def __init__(self, url, cvat_annotations):
self._categories = self._load_categories(cvat_annotations)

dm_annotations = []

for cvat_anno in cvat_annotations.group_by_frame():
dm_anno = self._read_cvat_anno(cvat_anno)
dm_item = datumaro.DatasetItem(
id=cvat_anno.frame, annotations=dm_anno)
for cvat_frame_anno in cvat_annotations.group_by_frame():
dm_anno = self._read_cvat_anno(cvat_frame_anno, cvat_annotations)
dm_image = Image(path=cvat_frame_anno.name, size=(
cvat_frame_anno.height, cvat_frame_anno.width)
)
dm_item = datumaro.DatasetItem(id=cvat_frame_anno.frame,
annotations=dm_anno, image=dm_image)
dm_annotations.append((dm_item.id, dm_item))

dm_annotations = sorted(dm_annotations, key=lambda e: int(e[0]))
self._items = OrderedDict(dm_annotations)

self._subsets = None

def __iter__(self):
for item in self._items.values():
yield item

def __len__(self):
return len(self._items)

# pylint: disable=no-self-use
def subsets(self):
return self._subsets
return []
# pylint: enable=no-self-use

def get(self, item_id, subset=None, path=None):
if path or subset:
raise KeyError()
return self._items[item_id]
def categories(self):
return self._categories

def _load_categories(self):
@staticmethod
def _load_categories(cvat_anno):
categories = {}
label_categories = datumaro.LabelCategories()

db_labels = self._db_task.label_set.all()
for db_label in db_labels:
db_attributes = db_label.attributespec_set.all()
label_categories.add(db_label.name)

for db_attr in db_attributes:
label_categories.attributes.add(db_attr.name)
for _, label in cvat_anno.meta['task']['labels']:
label_categories.add(label['name'])
for _, attr in label['attributes']:
label_categories.attributes.add(attr['name'])

categories[datumaro.AnnotationType.label] = label_categories

return categories

def categories(self):
return self._categories

def _read_cvat_anno(self, cvat_anno):
def _read_cvat_anno(self, cvat_frame_anno, cvat_task_anno):
item_anno = []

categories = self.categories()
label_cat = categories[datumaro.AnnotationType.label]

label_map = {}
label_attrs = {}
db_labels = self._db_task.label_set.all()
for db_label in db_labels:
label_map[db_label.name] = label_cat.find(db_label.name)[0]

attrs = {}
db_attributes = db_label.attributespec_set.all()
for db_attr in db_attributes:
attrs[db_attr.name] = db_attr
label_attrs[db_label.name] = attrs
map_label = lambda label_db_name: label_map[label_db_name]
map_label = lambda name: label_cat.find(name)[0]
label_attrs = {
label['name']: label['attributes']
for _, label in cvat_task_anno.meta['task']['labels']
}

def convert_attrs(label, cvat_attrs):
cvat_attrs = {a.name: a.value for a in cvat_attrs}
dm_attr = dict()
for attr_name, attr_spec in label_attrs[label].items():
attr_value = cvat_attrs.get(attr_name, attr_spec.default_value)
for _, a_desc in label_attrs[label]:
a_name = a_desc['name']
a_value = cvat_attrs.get(a_name, a_desc['default_value'])
try:
if attr_spec.input_type == AttributeType.NUMBER:
attr_value = float(attr_value)
elif attr_spec.input_type == AttributeType.CHECKBOX:
attr_value = attr_value.lower() == 'true'
dm_attr[attr_name] = attr_value
if a_desc['input_type'] == AttributeType.NUMBER:
a_value = float(a_value)
elif a_desc['input_type'] == AttributeType.CHECKBOX:
a_value = (a_value.lower() == 'true')
dm_attr[a_name] = a_value
except Exception as e:
slogger.task[self._db_task.id].error(
"Failed to convert attribute '%s'='%s': %s" % \
(attr_name, attr_value, e))
raise Exception(
"Failed to convert attribute '%s'='%s': %s" %
(a_name, a_value, e))
return dm_attr

for tag_obj in cvat_anno.tags:
for tag_obj in cvat_frame_anno.tags:
anno_group = tag_obj.group
anno_label = map_label(tag_obj.label)
anno_attr = convert_attrs(tag_obj.label, tag_obj.attributes)
Expand All @@ -163,7 +140,7 @@ def convert_attrs(label, cvat_attrs):
attributes=anno_attr, group=anno_group)
item_anno.append(anno)

for shape_obj in cvat_anno.labeled_shapes:
for shape_obj in cvat_frame_anno.labeled_shapes:
anno_group = shape_obj.group
anno_label = map_label(shape_obj.label)
anno_attr = convert_attrs(shape_obj.label, shape_obj.attributes)
Expand All @@ -183,8 +160,64 @@ def convert_attrs(label, cvat_attrs):
anno = datumaro.Bbox(x0, y0, x1 - x0, y1 - y0,
label=anno_label, attributes=anno_attr, group=anno_group)
else:
raise Exception("Unknown shape type '%s'" % (shape_obj.type))
raise Exception("Unknown shape type '%s'" % shape_obj.type)

item_anno.append(anno)

return item_anno
return item_anno


class CvatTaskExtractor(CvatAnnotationsExtractor):
    """Extractor that reads a task's annotations from the database.

    The stored annotations of ``db_task`` are loaded, wrapped in an
    ``Annotation`` object and handed to ``CvatAnnotationsExtractor``.
    """

    def __init__(self, url, db_task, user):
        """Load the annotations of ``db_task`` and initialize the base class.

        Args:
            url: Forwarded unchanged to the base class constructor.
            db_task: Task object; its ``id`` selects the annotations to load
                and the object itself is passed on to ``Annotation``.
            user: Passed to ``TaskAnnotation`` together with the task id.
        """
        task_annotation = TaskAnnotation(db_task.id, user)

        # Read everything within a single database transaction.
        with transaction.atomic():
            task_annotation.init_from_db()

        # NOTE(review): the listing this was taken from lost its indentation;
        # the wrapping below is assumed to happen after the transaction block
        # has been left - confirm against the repository.
        super().__init__(url, Annotation(task_annotation.ir_data, db_task))


def match_frame(item, cvat_task_anno):
    """Find the task frame number that corresponds to a dataset item.

    The following strategies are tried in order; the first one that yields a
    value other than ``None`` wins:

    1. ``cvat_task_anno.match_frame(item.id)``
    2. ``cvat_task_anno.match_frame(item.image.filename)`` - only when
       ``item.has_image`` is true
    3. ``int(item.id)``

    A failing strategy is deliberately ignored (best effort) so that the next
    one gets a chance; only the final membership check decides whether the
    item could be matched.

    Args:
        item: Dataset item providing ``id``, ``has_image`` and, when it has
            an image, ``image.filename``.
        cvat_task_anno: Task annotation object providing ``match_frame()``
            and a ``frame_info`` container of known frame numbers.

    Returns:
        The matched frame number, which is guaranteed to be contained in
        ``cvat_task_anno.frame_info``.

    Raises:
        Exception: If none of the strategies produced a frame number that is
            present in ``cvat_task_anno.frame_info``.
    """
    frame_number = None

    try:
        frame_number = cvat_task_anno.match_frame(item.id)
    except Exception:
        # Best effort: the matcher may reject the id in any way it likes.
        pass

    if frame_number is None and item.has_image:
        try:
            frame_number = cvat_task_anno.match_frame(item.image.filename)
        except Exception:
            # Best effort: fall through to the numeric interpretation.
            pass

    if frame_number is None:
        try:
            frame_number = int(item.id)
        except (TypeError, ValueError, OverflowError):
            # The id is simply not a number; the check below reports it.
            pass

    if frame_number not in cvat_task_anno.frame_info:
        # Wrap the id in a tuple so that ids which are tuples themselves
        # cannot break the %-formatting of the message.
        raise Exception("Could not match item id: '%s' with any task frame" %
            (item.id,))
    return frame_number

def import_dm_annotations(dm_dataset, cvat_task_anno):
    """Add the shape annotations of a Datumaro dataset to a CVAT task.

    Every dataset item is matched to a task frame with ``match_frame()``.
    Each of its annotations whose type is a bbox, polygon, polyline or points
    is added to ``cvat_task_anno`` as a ``LabeledShape``; annotations of any
    other type are skipped. Imported shapes are always created with
    ``occluded=False`` and an empty attribute list.

    Args:
        dm_dataset: Iterable dataset whose ``categories()`` contain the label
            categories and whose items carry ``annotations``.
        cvat_task_anno: Target object providing ``LabeledShape`` and
            ``add_shape()``.
    """
    shape_type_by_ann_type = {
        datumaro.AnnotationType.bbox: ShapeType.RECTANGLE,
        datumaro.AnnotationType.polygon: ShapeType.POLYGON,
        datumaro.AnnotationType.polyline: ShapeType.POLYLINE,
        datumaro.AnnotationType.points: ShapeType.POINTS,
    }

    labels = dm_dataset.categories()[datumaro.AnnotationType.label]

    for dm_item in dm_dataset:
        frame = match_frame(dm_item, cvat_task_anno)

        for dm_ann in dm_item.annotations:
            if dm_ann.type not in shape_type_by_ann_type:
                # No CVAT shape counterpart for this annotation type.
                continue

            shape = cvat_task_anno.LabeledShape(
                type=shape_type_by_ann_type[dm_ann.type],
                frame=frame,
                # The annotation stores an index into the label categories.
                label=labels.items[dm_ann.label].name,
                points=dm_ann.points,
                occluded=False,
                attributes=[],
            )
            cvat_task_anno.add_shape(shape)

0 comments on commit 1e640a2

Please sign in to comment.