diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py index 09f0a698e999..0c984e330cf9 100644 --- a/cvat/apps/dataset_manager/bindings.py +++ b/cvat/apps/dataset_manager/bindings.py @@ -14,7 +14,7 @@ from cvat.apps.engine.models import Task, ShapeType, AttributeType import datumaro.components.extractor as datumaro -from datumaro.util.image import lazy_image +from datumaro.util.image import Image class CvatImagesDirExtractor(datumaro.Extractor): @@ -29,8 +29,7 @@ def __init__(self, url): path = osp.join(dirpath, name) if self._is_image(path): item_id = Task.get_image_frame(path) - item = datumaro.DatasetItem( - id=item_id, image=lazy_image(path)) + item = datumaro.DatasetItem(id=item_id, image=path) items.append((item.id, item)) items = sorted(items, key=lambda e: int(e[0])) @@ -49,11 +48,6 @@ def __len__(self): def subsets(self): return self._subsets - def get(self, item_id, subset=None, path=None): - if path or subset: - raise KeyError() - return self._items[item_id] - def _is_image(self, path): for ext in self._SUPPORTED_FORMATS: if osp.isfile(path) and path.endswith(ext): @@ -61,29 +55,24 @@ def _is_image(self, path): return False -class CvatTaskExtractor(datumaro.Extractor): - def __init__(self, url, db_task, user): - self._db_task = db_task - self._categories = self._load_categories() - - cvat_annotations = TaskAnnotation(db_task.id, user) - with transaction.atomic(): - cvat_annotations.init_from_db() - cvat_annotations = Annotation(cvat_annotations.ir_data, db_task) +class CvatAnnotationsExtractor(datumaro.Extractor): + def __init__(self, url, cvat_annotations): + self._categories = self._load_categories(cvat_annotations) dm_annotations = [] - for cvat_anno in cvat_annotations.group_by_frame(): - dm_anno = self._read_cvat_anno(cvat_anno) - dm_item = datumaro.DatasetItem( - id=cvat_anno.frame, annotations=dm_anno) + for cvat_frame_anno in cvat_annotations.group_by_frame(): + dm_anno = self._read_cvat_anno(cvat_frame_anno, cvat_annotations) + dm_image = Image(path=cvat_frame_anno.name, size=( + cvat_frame_anno.height, cvat_frame_anno.width) + ) + dm_item = datumaro.DatasetItem(id=cvat_frame_anno.frame, + annotations=dm_anno, image=dm_image) dm_annotations.append((dm_item.id, dm_item)) dm_annotations = sorted(dm_annotations, key=lambda e: int(e[0])) self._items = OrderedDict(dm_annotations) - self._subsets = None - def __iter__(self): for item in self._items.values(): yield item @@ -91,70 +80,58 @@ def __iter__(self): def __len__(self): return len(self._items) + # pylint: disable=no-self-use def subsets(self): - return self._subsets + return [] + # pylint: enable=no-self-use - def get(self, item_id, subset=None, path=None): - if path or subset: - raise KeyError() - return self._items[item_id] + def categories(self): + return self._categories - def _load_categories(self): + @staticmethod + def _load_categories(cvat_anno): categories = {} label_categories = datumaro.LabelCategories() - db_labels = self._db_task.label_set.all() - for db_label in db_labels: - db_attributes = db_label.attributespec_set.all() - label_categories.add(db_label.name) - - for db_attr in db_attributes: - label_categories.attributes.add(db_attr.name) + for _, label in cvat_anno.meta['task']['labels']: + label_categories.add(label['name']) + for _, attr in label['attributes']: + label_categories.attributes.add(attr['name']) categories[datumaro.AnnotationType.label] = label_categories return categories - def categories(self): - return self._categories - - def _read_cvat_anno(self, cvat_anno): + def _read_cvat_anno(self, cvat_frame_anno, cvat_task_anno): item_anno = [] categories = self.categories() label_cat = categories[datumaro.AnnotationType.label] - - label_map = {} - label_attrs = {} - db_labels = self._db_task.label_set.all() - for db_label in db_labels: - label_map[db_label.name] = label_cat.find(db_label.name)[0] - - attrs = {} - db_attributes = db_label.attributespec_set.all() - for db_attr in db_attributes: - attrs[db_attr.name] = db_attr - label_attrs[db_label.name] = attrs - map_label = lambda label_db_name: label_map[label_db_name] + map_label = lambda name: label_cat.find(name)[0] + label_attrs = { + label['name']: label['attributes'] + for _, label in cvat_task_anno.meta['task']['labels'] + } def convert_attrs(label, cvat_attrs): cvat_attrs = {a.name: a.value for a in cvat_attrs} dm_attr = dict() - for attr_name, attr_spec in label_attrs[label].items(): - attr_value = cvat_attrs.get(attr_name, attr_spec.default_value) + for _, a_desc in label_attrs[label]: + a_name = a_desc['name'] + a_value = cvat_attrs.get(a_name, a_desc['default_value']) try: - if attr_spec.input_type == AttributeType.NUMBER: - attr_value = float(attr_value) - elif attr_spec.input_type == AttributeType.CHECKBOX: - attr_value = attr_value.lower() == 'true' - dm_attr[attr_name] = attr_value + if a_desc['input_type'] == AttributeType.NUMBER: + a_value = float(a_value) + elif a_desc['input_type'] == AttributeType.CHECKBOX: + a_value = (a_value.lower() == 'true') + dm_attr[a_name] = a_value except Exception as e: - slogger.task[self._db_task.id].error( - "Failed to convert attribute '%s'='%s': %s" % \ - (attr_name, attr_value, e)) + raise Exception( + "Failed to convert attribute '%s'='%s': %s" % + (a_name, a_value, e)) return dm_attr - for tag_obj in cvat_anno.tags: + for tag_obj in cvat_frame_anno.tags: anno_group = tag_obj.group anno_label = map_label(tag_obj.label) anno_attr = convert_attrs(tag_obj.label, tag_obj.attributes) @@ -163,7 +140,7 @@ def convert_attrs(label, cvat_attrs): attributes=anno_attr, group=anno_group) item_anno.append(anno) - for shape_obj in cvat_anno.labeled_shapes: + for shape_obj in cvat_frame_anno.labeled_shapes: anno_group = shape_obj.group anno_label = map_label(shape_obj.label) anno_attr = convert_attrs(shape_obj.label, shape_obj.attributes) @@ -183,8 +160,64 @@ def convert_attrs(label, cvat_attrs): anno = datumaro.Bbox(x0, y0, x1 - x0, y1 - y0, label=anno_label, attributes=anno_attr, group=anno_group) else: - raise Exception("Unknown shape type '%s'" % (shape_obj.type)) + raise Exception("Unknown shape type '%s'" % shape_obj.type) item_anno.append(anno) - return item_anno \ No newline at end of file + return item_anno + + +class CvatTaskExtractor(CvatAnnotationsExtractor): + def __init__(self, url, db_task, user): + cvat_annotations = TaskAnnotation(db_task.id, user) + with transaction.atomic(): + cvat_annotations.init_from_db() + cvat_annotations = Annotation(cvat_annotations.ir_data, db_task) + super().__init__(url, cvat_annotations) + + +def match_frame(item, cvat_task_anno): + frame_number = None + if frame_number is None: + try: + frame_number = cvat_task_anno.match_frame(item.id) + except Exception: + pass + if frame_number is None and item.has_image: + try: + frame_number = cvat_task_anno.match_frame(item.image.filename) + except Exception: + pass + if frame_number is None: + try: + frame_number = int(item.id) + except Exception: + pass + if not frame_number in cvat_task_anno.frame_info: + raise Exception("Could not match item id: '%s' with any task frame" % + item.id) + return frame_number + +def import_dm_annotations(dm_dataset, cvat_task_anno): + shapes = { + datumaro.AnnotationType.bbox: ShapeType.RECTANGLE, + datumaro.AnnotationType.polygon: ShapeType.POLYGON, + datumaro.AnnotationType.polyline: ShapeType.POLYLINE, + datumaro.AnnotationType.points: ShapeType.POINTS, + } + + label_cat = dm_dataset.categories()[datumaro.AnnotationType.label] + + for item in dm_dataset: + frame_number = match_frame(item, cvat_task_anno) + + for ann in item.annotations: + if ann.type in shapes: + cvat_task_anno.add_shape(cvat_task_anno.LabeledShape( + type=shapes[ann.type], + frame=frame_number, + label=label_cat.items[ann.label].name, + points=ann.points, + occluded=False, + attributes=[], + )) \ No newline at end of file