Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SDK: add a ProjectVisionDataset class #5523

Merged
merged 3 commits into from
Dec 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Filename pattern to simplify uploading cloud storage data for a task (<https://github.com/opencv/cvat/pull/5498>)
- \[SDK\] Configuration setting to change the dataset cache directory
(<https://github.com/opencv/cvat/pull/5535>)
- \[SDK\] Class to represent a project as a PyTorch dataset
(<https://github.com/opencv/cvat/pull/5523>)

### Changed
- The Docker Compose files now use the Compose Specification version
Expand Down
144 changes: 119 additions & 25 deletions cvat-sdk/cvat_sdk/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,22 @@
import types
import zipfile
from concurrent.futures import ThreadPoolExecutor
from typing import (
Callable,
Dict,
FrozenSet,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
from pathlib import Path
from typing import Callable, Container, Dict, FrozenSet, List, Mapping, Optional, Type, TypeVar

import attrs
import attrs.validators
import PIL.Image
import torch
import torch.utils.data
import torchvision.datasets
from typing_extensions import TypedDict

import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.api_client.model_utils import to_json
from cvat_sdk.core.utils import atomic_writer
from cvat_sdk.models import DataMetaRead, LabeledData, LabeledImage, LabeledShape, TaskRead

_ModelType = TypeVar("_ModelType")

Expand All @@ -47,8 +38,8 @@ class FrameAnnotations:
Contains annotations that pertain to a single frame.
"""

tags: List[LabeledImage] = attrs.Factory(list)
shapes: List[LabeledShape] = attrs.Factory(list)
tags: List[models.LabeledImage] = attrs.Factory(list)
shapes: List[models.LabeledShape] = attrs.Factory(list)


@attrs.frozen
Expand All @@ -67,6 +58,12 @@ class Target:
"""


def _get_server_dir(client: cvat_sdk.core.Client) -> Path:
    """Return the local cache directory corresponding to the server `client` talks to."""
    # The host name may contain filesystem-unsafe characters (such as slashes),
    # so derive the directory name from its URL-safe Base64 encoding, with the
    # padding stripped.
    host_bytes = client.api_map.host.encode()
    safe_name = base64.urlsafe_b64encode(host_bytes).rstrip(b"=").decode()
    return client.config.cache_dir / "servers" / safe_name


class TaskVisionDataset(torchvision.datasets.VisionDataset):
"""
Represents a task on a CVAT server as a PyTorch Dataset.
Expand Down Expand Up @@ -132,13 +129,7 @@ def __init__(
f" current chunk type is {self._task.data_original_chunk_type!r}"
)

zhiltsov-max marked this conversation as resolved.
Show resolved Hide resolved
# Base64-encode the name to avoid FS-unsafe characters (like slashes)
server_dir_name = (
base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode()
)
server_dir = client.config.cache_dir / f"servers/{server_dir_name}"

self._task_dir = server_dir / f"tasks/{self._task.id}"
self._task_dir = _get_server_dir(client) / f"tasks/{self._task.id}"
self._initialize_task_dir()

super().__init__(
Expand All @@ -149,7 +140,7 @@ def __init__(
)

data_meta = self._ensure_model(
"data_meta.json", DataMetaRead, self._task.get_meta, "data metadata"
"data_meta.json", models.DataMetaRead, self._task.get_meta, "data metadata"
)
self._active_frame_indexes = sorted(
set(range(self._task.size)) - set(data_meta.deleted_frames)
Expand Down Expand Up @@ -186,7 +177,7 @@ def __init__(
)

annotations = self._ensure_model(
"annotations.json", LabeledData, self._task.get_annotations, "annotations"
"annotations.json", models.LabeledData, self._task.get_annotations, "annotations"
)

self._frame_annotations: Dict[int, FrameAnnotations] = collections.defaultdict(
Expand All @@ -206,7 +197,7 @@ def _initialize_task_dir(self) -> None:

try:
with open(task_json_path, "rb") as task_json_file:
saved_task = TaskRead._new_from_openapi_data(**json.load(task_json_file))
saved_task = models.TaskRead._new_from_openapi_data(**json.load(task_json_file))
except Exception:
self._logger.info("Task is not yet cached or the cache is corrupted")

Expand Down Expand Up @@ -295,6 +286,109 @@ def __len__(self) -> int:
return len(self._active_frame_indexes)


class ProjectVisionDataset(torchvision.datasets.VisionDataset):
    """
    Represents a project on a CVAT server as a PyTorch Dataset.

    The dataset contains one sample for each frame of each task in the project
    (except for tasks that are filtered out - see the description of `task_filter`
    in the constructor). The sequence of samples is formed by concatenating sequences
    of samples from all included tasks in an arbitrary order that's consistent
    between executions. Each task's sequence of samples corresponds to the sequence
    of frames on the server.

    See `TaskVisionDataset` for information on sample format, caching, and
    current limitations.
    """

    def __init__(
        self,
        client: cvat_sdk.core.Client,
        project_id: int,
        *,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        label_name_to_index: Optional[Mapping[str, int]] = None,
        task_filter: Optional[Callable[[models.ITaskRead], bool]] = None,
        include_subsets: Optional[Container[str]] = None,
    ) -> None:
        """
        Creates a dataset corresponding to the project with ID `project_id` on the
        server that `client` is connected to.

        `transforms`, `transform` and `target_transform` are optional transformation
        functions; see the documentation for `torchvision.datasets.VisionDataset` for
        more information.

        See `TaskVisionDataset.__init__` for information on `label_name_to_index`.

        By default, all of the project's tasks will be included in the dataset.
        The following parameters can be specified to exclude some tasks:

        * If `task_filter` is set to a callable object, it will be applied to every task.
          Tasks for which it returns a false value will be excluded.

        * If `include_subsets` is set to a container, then tasks whose subset is
          not a member of this container will be excluded.
        """

        self._logger = client.logger

        self._logger.info(f"Fetching project {project_id}...")
        project = client.projects.retrieve(project_id)

        # We don't actually need to save anything to this directory (yet),
        # but VisionDataset.__init__ requires a root, so make one.
        # It could be useful in the future to store the project data for
        # offline-only mode.
        project_dir = _get_server_dir(client) / f"projects/{project_id}"
        project_dir.mkdir(parents=True, exist_ok=True)

        super().__init__(
            os.fspath(project_dir),
            transforms=transforms,
            transform=transform,
            target_transform=target_transform,
        )

        self._logger.info("Fetching project tasks...")
        tasks = project.get_tasks()

        if task_filter is not None:
            tasks = list(filter(task_filter, tasks))

        if include_subsets is not None:
            tasks = [task for task in tasks if task.subset in include_subsets]

        tasks.sort(key=lambda t: t.id)  # ensure consistent order between executions

        self._underlying = torch.utils.data.ConcatDataset(
            [
                TaskVisionDataset(client, task.id, label_name_to_index=label_name_to_index)
                for task in tasks
            ]
        )

    def __getitem__(self, sample_index: int):
        """
        Returns the sample with index `sample_index`.

        `sample_index` must satisfy the condition `0 <= sample_index < len(self)`.
        """

        sample_image, sample_target = self._underlying[sample_index]

        if self.transforms:
            sample_image, sample_target = self.transforms(sample_image, sample_target)

        return sample_image, sample_target

    def __len__(self) -> int:
        """Returns the number of samples in the dataset."""
        return len(self._underlying)


@attrs.frozen
class ExtractSingleLabelIndex:
"""
Expand Down
Loading