SDK: add a ProjectVisionDataset class
It's the analog of `TaskVisionDataset` for projects.
SpecLad committed Dec 28, 2022
1 parent aa4980e commit f03f0ec
Showing 3 changed files with 251 additions and 34 deletions.
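For context, a minimal usage sketch of the class this commit adds (not part of the commit itself): the host, credentials, and project ID below are placeholders, and the helper calls (`make_client`, `torchvision.transforms`, `DataLoader`) are the standard SDK and PyTorch entry points.

```python
import torchvision.transforms
from torch.utils.data import DataLoader

from cvat_sdk import make_client
from cvat_sdk.pytorch import ExtractSingleLabelIndex, ProjectVisionDataset

# Placeholder host, credentials, and project ID.
client = make_client("http://localhost:8080", credentials=("user", "password"))

dataset = ProjectVisionDataset(
    client,
    project_id=42,
    # Resize so that samples can be batched even if frame sizes differ.
    transform=torchvision.transforms.Compose(
        [torchvision.transforms.Resize((224, 224)), torchvision.transforms.ToTensor()]
    ),
    # Extracts a single label index per frame (expects one tag per frame).
    target_transform=ExtractSingleLabelIndex(),
)

# Behaves like any other torchvision-style dataset.
loader = DataLoader(dataset, batch_size=4, shuffle=True)
for images, labels in loader:
    ...  # training / evaluation loop goes here
```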
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## \[2.4.0] - Unreleased
### Added
- Filename pattern to simplify uploading cloud storage data for a task (<https://github.com/opencv/cvat/pull/5498>)
- SDK: class to represent a project as a PyTorch dataset
(<https://github.com/opencv/cvat/pull/5523>)

### Changed
- TDB
115 changes: 102 additions & 13 deletions cvat-sdk/cvat_sdk/pytorch/__init__.py
@@ -25,14 +25,15 @@
import attrs.validators
import PIL.Image
import torch
import torch.utils.data
import torchvision.datasets
from typing_extensions import TypedDict

import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.api_client.model_utils import to_json
from cvat_sdk.core.utils import atomic_writer
from cvat_sdk.models import DataMetaRead, LabeledData, LabeledImage, LabeledShape, TaskRead

_ModelType = TypeVar("_ModelType")

@@ -50,8 +51,8 @@ class FrameAnnotations:
Contains annotations that pertain to a single frame.
"""

tags: List[LabeledImage] = attrs.Factory(list)
shapes: List[LabeledShape] = attrs.Factory(list)
tags: List[models.LabeledImage] = attrs.Factory(list)
shapes: List[models.LabeledShape] = attrs.Factory(list)


@attrs.frozen
@@ -70,6 +71,12 @@ class Target:
"""


def _get_server_dir(client: cvat_sdk.core.Client) -> Path:
# Base64-encode the name to avoid FS-unsafe characters (like slashes)
server_dir_name = base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode()
return _CACHE_DIR / f"servers/{server_dir_name}"


class TaskVisionDataset(torchvision.datasets.VisionDataset):
"""
Represents a task on a CVAT server as a PyTorch Dataset.
@@ -135,13 +142,7 @@ def __init__(
f" current chunk type is {self._task.data_original_chunk_type!r}"
)

# Base64-encode the name to avoid FS-unsafe characters (like slashes)
server_dir_name = (
base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode()
)
server_dir = _CACHE_DIR / f"servers/{server_dir_name}"

self._task_dir = server_dir / f"tasks/{self._task.id}"
self._task_dir = _get_server_dir(client) / f"tasks/{self._task.id}"
self._initialize_task_dir()

super().__init__(
@@ -152,7 +153,7 @@ def __init__(
)

data_meta = self._ensure_model(
"data_meta.json", DataMetaRead, self._task.get_meta, "data metadata"
"data_meta.json", models.DataMetaRead, self._task.get_meta, "data metadata"
)
self._active_frame_indexes = sorted(
set(range(self._task.size)) - set(data_meta.deleted_frames)
@@ -189,7 +190,7 @@ def __init__(
)

annotations = self._ensure_model(
"annotations.json", LabeledData, self._task.get_annotations, "annotations"
"annotations.json", models.LabeledData, self._task.get_annotations, "annotations"
)

self._frame_annotations: Dict[int, FrameAnnotations] = collections.defaultdict(
@@ -209,7 +210,7 @@ def _initialize_task_dir(self) -> None:

try:
with open(task_json_path, "rb") as task_json_file:
saved_task = TaskRead._new_from_openapi_data(**json.load(task_json_file))
saved_task = models.TaskRead._new_from_openapi_data(**json.load(task_json_file))
except Exception:
self._logger.info("Task is not yet cached or the cache is corrupted")

@@ -298,6 +299,94 @@ def __len__(self) -> int:
return len(self._active_frame_indexes)


class ProjectVisionDataset(torchvision.datasets.VisionDataset):
"""
Represents a project on a CVAT server as a PyTorch Dataset.
The dataset contains one sample for each frame of each task in the project
(except for tasks that are filtered out - see the description of `task_filter`
in the constructor). The sequence of samples is formed by concatenating sequences
of samples from all included tasks in an arbitrary order that's consistent
between executions. Each task's sequence of samples corresponds to the sequence
of frames on the server.
See `TaskVisionDataset` for information on sample format, caching, and
current limitations.
"""

def __init__(
self,
client: cvat_sdk.core.Client,
project_id: int,
*,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
label_name_to_index: Optional[Mapping[str, int]] = None,
task_filter: Optional[Callable[[models.ITaskRead], bool]] = None,
) -> None:
"""
Creates a dataset corresponding to the project with ID `project_id` on the
server that `client` is connected to.
`transforms`, `transform` and `target_transform` are optional transformation
functions; see the documentation for `torchvision.datasets.VisionDataset` for
more information.
See `TaskVisionDataset.__init__` for information on `label_name_to_index`.
`task_filter`, if set to a non-`None` value, determines which of the project's
tasks will be included in the dataset. The filter will be called for each task,
and only tasks for which it returns a true value will be included.
If `task_filter` is set to `None`, then all of the project's tasks will be included.
"""

self._logger = client.logger

self._logger.info(f"Fetching project {project_id}...")
project = client.projects.retrieve(project_id)

# We don't actually need to save anything to this directory (yet),
# but VisionDataset.__init__ requires a root, so make one.
# It could be useful in the future to store the project data for
# offline-only mode.
project_dir = _get_server_dir(client) / f"projects/{project_id}"
project_dir.mkdir(parents=True, exist_ok=True)

super().__init__(
os.fspath(project_dir),
transforms=transforms,
transform=transform,
target_transform=target_transform,
)

self._logger.info("Fetching project tasks...")
tasks = project.get_tasks()
if task_filter:
tasks = list(filter(task_filter, tasks))
tasks.sort(key=lambda t: t.id) # ensure consistent order between executions

self._underlying = torch.utils.data.ConcatDataset(
[
TaskVisionDataset(client, task.id, label_name_to_index=label_name_to_index)
for task in tasks
]
)

def __getitem__(self, sample_index: int):
sample_image, sample_target = self._underlying[sample_index]

if self.transforms:
sample_image, sample_target = self.transforms(sample_image, sample_target)

return sample_image, sample_target

def __len__(self) -> int:
"""Returns the number of samples in the dataset."""
return len(self._underlying)


@attrs.frozen
class ExtractSingleLabelIndex:
"""
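As a quick illustration of the two project-specific options above, here is a hedged sketch (the subset name and index values are placeholders, and `client` / `project_id` are assumed to exist already):

```python
from cvat_sdk.pytorch import ProjectVisionDataset

# Include only tasks whose subset is not "Test" (placeholder subset name).
# The filter is called once per task; tasks for which it returns a falsy
# value are excluded from the dataset.
train_and_val = ProjectVisionDataset(
    client,
    project_id,
    task_filter=lambda task: task.subset != "Test",
)

# Map label names to custom indices instead of the default assignment
# (the values 0 and 1 here are arbitrary).
custom_mapping = ProjectVisionDataset(
    client,
    project_id,
    label_name_to_index={"person": 0, "car": 1},
)
```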
168 changes: 147 additions & 21 deletions tests/python/sdk/test_pytorch.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: MIT

import io
import itertools
import os
from logging import Logger
from pathlib import Path
@@ -25,37 +26,36 @@
from shared.utils.helpers import generate_image_files


@pytest.fixture(autouse=True)
def _common_setup(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO],
):
logger = fxt_logger[0]
client = fxt_login[0]
client.logger = logger

api_client = client.api_client
for k in api_client.configuration.logger:
api_client.configuration.logger[k] = logger

monkeypatch.setattr(cvatpt, "_CACHE_DIR", tmp_path / "cache")


@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
class TestTaskVisionDataset:
@pytest.fixture(autouse=True)
def setup(
self,
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO],
fxt_stdout: io.StringIO,
):
self.tmp_path = tmp_path
logger, self.logger_stream = fxt_logger
self.stdout = fxt_stdout
self.client, self.user = fxt_login
self.client.logger = logger

api_client = self.client.api_client
for k in api_client.configuration.logger:
api_client.configuration.logger[k] = logger

monkeypatch.setattr(cvatpt, "_CACHE_DIR", self.tmp_path / "cache")

self._create_task()

yield

def _create_task(self):
self.client = fxt_login[0]
self.images = generate_image_files(10)

image_dir = self.tmp_path / "images"
image_dir = tmp_path / "images"
image_dir.mkdir()

image_paths = []
@@ -224,3 +224,129 @@ def test_custom_label_mapping(self):
_, target = dataset[5]
assert target.label_id_to_index[label_name_to_id["person"]] == 123
assert target.label_id_to_index[label_name_to_id["car"]] == 456


@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
class TestProjectVisionDataset:
@pytest.fixture(autouse=True)
def setup(
self,
tmp_path: Path,
fxt_login: Tuple[Client, str],
):
self.client = fxt_login[0]

self.project = self.client.projects.create(
models.ProjectWriteRequest(
"PyTorch integration test project",
labels=[
models.PatchedLabelRequest(name="person"),
models.PatchedLabelRequest(name="car"),
],
)
)
self.label_ids = sorted(l.id for l in self.project.labels)

all_images = generate_image_files(9)
self.images_per_task = [all_images[i * 3 : i * 3 + 3] for i in range(3)]

image_dir = tmp_path / "images"
image_dir.mkdir()

image_paths_per_task = []
for images in self.images_per_task:
image_paths = []
for image in images:
image_path = image_dir / image.name
image_path.write_bytes(image.getbuffer())
image_paths.append(image_path)
image_paths_per_task.append(image_paths)

self.tasks = [
self.client.tasks.create_from_data(
models.TaskWriteRequest(
"PyTorch integration test task",
project_id=self.project.id,
subset=subset,
),
ResourceType.LOCAL,
image_paths,
data_params={"image_quality": 70},
)
for subset, image_paths in zip(["Train", "Test", "Val"], image_paths_per_task)
]

# sort both self.tasks and self.images_per_task in the order that ProjectVisionDataset uses
self.tasks, self.images_per_task = zip(
*sorted(zip(self.tasks, self.images_per_task), key=lambda t: t[0].id)
)

for task_id, label_index in ((0, 0), (1, 1), (2, 0)):
self.tasks[task_id].update_annotations(
models.PatchedLabeledDataRequest(
tags=[
models.LabeledImageRequest(
frame=task_id, label_id=self.label_ids[label_index]
),
],
)
)

def test_basic(self):
dataset = cvatpt.ProjectVisionDataset(self.client, self.project.id)

assert len(dataset) == sum(task.size for task in self.tasks)

for sample, image in zip(dataset, itertools.chain.from_iterable(self.images_per_task)):
assert torch.equal(TF.pil_to_tensor(sample[0]), TF.pil_to_tensor(PIL.Image.open(image)))

assert dataset[0][1].annotations.tags[0].label_id == self.label_ids[0]
assert dataset[4][1].annotations.tags[0].label_id == self.label_ids[1]
assert dataset[8][1].annotations.tags[0].label_id == self.label_ids[0]

def test_task_filter(self):
dataset = cvatpt.ProjectVisionDataset(
self.client, self.project.id, task_filter=lambda t: t.subset != self.tasks[0].subset
)

assert len(dataset) == sum(task.size for task in self.tasks[1:])

for sample, image in zip(dataset, itertools.chain.from_iterable(self.images_per_task[1:])):
assert torch.equal(TF.pil_to_tensor(sample[0]), TF.pil_to_tensor(PIL.Image.open(image)))

assert dataset[1][1].annotations.tags[0].label_id == self.label_ids[1]
assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[0]

def test_custom_label_mapping(self):
label_name_to_id = {label.name: label.id for label in self.project.labels}

dataset = cvatpt.ProjectVisionDataset(
self.client, self.project.id, label_name_to_index={"person": 123, "car": 456}
)

_, target = dataset[5]
assert target.label_id_to_index[label_name_to_id["person"]] == 123
assert target.label_id_to_index[label_name_to_id["car"]] == 456

def test_separate_transforms(self):
dataset = cvatpt.ProjectVisionDataset(
self.client,
self.project.id,
transform=torchvision.transforms.ToTensor(),
target_transform=cvatpt.ExtractSingleLabelIndex(),
)

assert torch.equal(
dataset[0][0], TF.pil_to_tensor(PIL.Image.open(self.images_per_task[0][0]))
)
assert torch.equal(dataset[0][1], torch.tensor(0))

def test_combined_transforms(self):
dataset = cvatpt.ProjectVisionDataset(
self.client,
self.project.id,
transforms=lambda x, y: (y, x),
)

assert isinstance(dataset[0][0], cvatpt.Target)
assert isinstance(dataset[0][1], PIL.Image.Image)

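For reference, a small sketch of what the tests above access on each sample, assuming a `ProjectVisionDataset` constructed without transforms (the index 0 is arbitrary; the attribute names come from the `Target` and `FrameAnnotations` classes in `cvat_sdk.pytorch`):

```python
image, target = dataset[0]   # PIL.Image.Image, cvat_sdk.pytorch.Target

# Tag annotations recorded for this frame on the server.
first_tag = target.annotations.tags[0]
print(first_tag.label_id)                            # server-side label ID

# Mapping from server-side label IDs to (possibly custom) training indices.
print(target.label_id_to_index[first_tag.label_id])
```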