Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create cvat_sdk.datasets, a framework-agnostic version of cvat_sdk.pytorch #6428

Merged
merged 4 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## \[Unreleased]
### Added
- Multi-line text attributes supported (<https://github.com/opencv/cvat/pull/6458>)
- \{SDK\] `cvat_sdk.datasets`, a framework-agnostic equivalent of `cvat_sdk.pytorch`
(<https://github.com/opencv/cvat/pull/6428>)

### Changed
- TDB
Expand Down
7 changes: 7 additions & 0 deletions cvat-sdk/cvat_sdk/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (C) 2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from .caching import UpdatePolicy
from .common import FrameAnnotations, MediaElement, Sample, UnsupportedDatasetError
from .task_dataset import TaskDataset
57 changes: 57 additions & 0 deletions cvat-sdk/cvat_sdk/datasets/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (C) 2022-2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import abc
from typing import List

import attrs
import attrs.validators
import PIL.Image

import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models


class UnsupportedDatasetError(cvat_sdk.core.exceptions.CvatSdkException):
pass


@attrs.frozen
class FrameAnnotations:
"""
Contains annotations that pertain to a single frame.
"""

tags: List[models.LabeledImage] = attrs.Factory(list)
shapes: List[models.LabeledShape] = attrs.Factory(list)


class MediaElement(metaclass=abc.ABCMeta):
"""
The media part of a dataset sample.
"""

@abc.abstractmethod
def load_image(self) -> PIL.Image.Image:
"""
Loads the media data and returns it as a PIL Image object.
"""
...


@attrs.frozen
class Sample:
"""
Represents an element of a dataset.
"""

frame_index: int
"""Index of the corresponding frame in its task."""

annotations: FrameAnnotations
"""Annotations belonging to the frame."""

media: MediaElement
"""Media data of the frame."""
164 changes: 164 additions & 0 deletions cvat-sdk/cvat_sdk/datasets/task_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright (C) 2022-2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from __future__ import annotations

import zipfile
from concurrent.futures import ThreadPoolExecutor
from typing import Sequence

import PIL.Image

import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.datasets.caching import UpdatePolicy, make_cache_manager
from cvat_sdk.datasets.common import FrameAnnotations, MediaElement, Sample, UnsupportedDatasetError

_NUM_DOWNLOAD_THREADS = 4


class TaskDataset:
"""
Represents a task on a CVAT server as a collection of samples.

Each sample corresponds to one frame in the task, and provides access to
the corresponding annotations and media data. Deleted frames are omitted.

This class caches all data and annotations for the task on the local file system
during construction.

Limitations:

* Only tasks with image (not video) data are supported at the moment.
* Track annotations are currently not accessible.
"""

class _TaskMediaElement(MediaElement):
def __init__(self, dataset: TaskDataset, frame_index: int) -> None:
self._dataset = dataset
self._frame_index = frame_index

def load_image(self) -> PIL.Image.Image:
return self._dataset._load_frame_image(self._frame_index)

def __init__(
self,
client: cvat_sdk.core.Client,
task_id: int,
*,
update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE,
) -> None:
"""
Creates a dataset corresponding to the task with ID `task_id` on the
server that `client` is connected to.

`update_policy` determines when and if the local cache will be updated.
"""

self._logger = client.logger

cache_manager = make_cache_manager(client, update_policy)
self._task = cache_manager.retrieve_task(task_id)

if not self._task.size or not self._task.data_chunk_size:
raise UnsupportedDatasetError("The task has no data")

if self._task.data_original_chunk_type != "imageset":
raise UnsupportedDatasetError(
f"{self.__class__.__name__} only supports tasks with image chunks;"
f" current chunk type is {self._task.data_original_chunk_type!r}"
)

self._logger.info("Fetching labels...")
self._labels = tuple(self._task.get_labels())

data_meta = cache_manager.ensure_task_model(
self._task.id,
"data_meta.json",
models.DataMetaRead,
self._task.get_meta,
"data metadata",
)

active_frame_indexes = set(range(self._task.size)) - set(data_meta.deleted_frames)

self._logger.info("Downloading chunks...")

self._chunk_dir = cache_manager.chunk_dir(task_id)
self._chunk_dir.mkdir(exist_ok=True, parents=True)

needed_chunks = {index // self._task.data_chunk_size for index in active_frame_indexes}

with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool:

def ensure_chunk(chunk_index):
cache_manager.ensure_chunk(self._task, chunk_index)

for _ in pool.map(ensure_chunk, sorted(needed_chunks)):
# just need to loop through all results so that any exceptions are propagated
pass

self._logger.info("All chunks downloaded")

annotations = cache_manager.ensure_task_model(
self._task.id,
"annotations.json",
models.LabeledData,
self._task.get_annotations,
"annotations",
)

self._frame_annotations = {
frame_index: FrameAnnotations() for frame_index in sorted(active_frame_indexes)
}

for tag in annotations.tags:
# Some annotations may belong to deleted frames; skip those.
if tag.frame in self._frame_annotations:
self._frame_annotations[tag.frame].tags.append(tag)

for shape in annotations.shapes:
if shape.frame in self._frame_annotations:
self._frame_annotations[shape.frame].shapes.append(shape)

# TODO: tracks?

self._samples = [
Sample(frame_index=k, annotations=v, media=self._TaskMediaElement(self, k))
for k, v in self._frame_annotations.items()
]

@property
def labels(self) -> Sequence[models.ILabel]:
"""
Returns the labels configured in the task.

Clients must not modify the object returned by this property or its components.
"""
return self._labels

@property
def samples(self) -> Sequence[Sample]:
"""
Returns a sequence of all samples, in order of their frame indices.

Note that the frame indices may not be contiguous, as deleted frames will not be included.

Clients must not modify the object returned by this property or its components.
"""
return self._samples

def _load_frame_image(self, frame_index: int) -> PIL.Image:
assert frame_index in self._frame_annotations

chunk_index = frame_index // self._task.data_chunk_size
member_index = frame_index % self._task.data_chunk_size

with zipfile.ZipFile(self._chunk_dir / f"{chunk_index}.zip", "r") as chunk_zip:
with chunk_zip.open(chunk_zip.infolist()[member_index]) as chunk_member:
image = PIL.Image.open(chunk_member)
image.load()

return image
8 changes: 6 additions & 2 deletions cvat-sdk/cvat_sdk/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
#
# SPDX-License-Identifier: MIT

from .caching import UpdatePolicy
from .common import FrameAnnotations, Target, UnsupportedDatasetError
from .common import Target
from .project_dataset import ProjectVisionDataset
from .task_dataset import TaskVisionDataset
from .transforms import ExtractBoundingBoxes, ExtractSingleLabelIndex, LabeledBoxes

# isort: split
# Compatibility imports
from ..datasets.caching import UpdatePolicy
from ..datasets.common import FrameAnnotations, UnsupportedDatasetError
21 changes: 2 additions & 19 deletions cvat-sdk/cvat_sdk/pytorch/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,11 @@
#
# SPDX-License-Identifier: MIT

from typing import List, Mapping
from typing import Mapping

import attrs
import attrs.validators

import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models


class UnsupportedDatasetError(cvat_sdk.core.exceptions.CvatSdkException):
pass


@attrs.frozen
class FrameAnnotations:
"""
Contains annotations that pertain to a single frame.
"""

tags: List[models.LabeledImage] = attrs.Factory(list)
shapes: List[models.LabeledShape] = attrs.Factory(list)
from cvat_sdk.datasets.common import FrameAnnotations


@attrs.frozen
Expand Down
2 changes: 1 addition & 1 deletion cvat-sdk/cvat_sdk/pytorch/project_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.pytorch.caching import UpdatePolicy, make_cache_manager
from cvat_sdk.datasets.caching import UpdatePolicy, make_cache_manager
from cvat_sdk.pytorch.task_dataset import TaskVisionDataset


Expand Down
Loading