Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't fetch existing annotations in cvat_sdk.auto_annotation.annotate_task #7019

Merged
merged 1 commit into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions changelog.d/20231017_194639_roman_auto_annotate_no_load.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
### Added

- \[SDK\] A parameter to `TaskDataset` that allows you to disable annotation loading
(<https://github.com/opencv/cvat/pull/7019>)

### Fixed
- \[SDK\] `cvat_sdk.auto_annotation.annotate_task` no longer performs
unnecessary fetches of existing annotations
(<https://github.com/opencv/cvat/pull/7019>)
2 changes: 1 addition & 1 deletion cvat-sdk/cvat_sdk/auto_annotation/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def annotate_task(
if pbar is None:
pbar = NullProgressReporter()

dataset = TaskDataset(client, task_id)
dataset = TaskDataset(client, task_id, load_annotations=False)

assert isinstance(function.spec, DetectionFunctionSpec)

Expand Down
10 changes: 7 additions & 3 deletions cvat-sdk/cvat_sdk/datasets/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: MIT

import abc
from typing import List
from typing import List, Optional

import attrs
import attrs.validators
Expand Down Expand Up @@ -53,8 +53,12 @@ class Sample:
frame_name: str
"""File name of the frame in its task."""

annotations: FrameAnnotations
"""Annotations belonging to the frame."""
annotations: Optional[FrameAnnotations]
"""
Annotations belonging to the frame.

Will be None if the dataset was created without loading annotations.
"""

media: MediaElement
"""Media data of the frame."""
45 changes: 28 additions & 17 deletions cvat-sdk/cvat_sdk/datasets/task_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

import zipfile
from concurrent.futures import ThreadPoolExecutor
from typing import Sequence
from typing import Iterable, Sequence

import PIL.Image

import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.datasets.caching import UpdatePolicy, make_cache_manager
from cvat_sdk.datasets.caching import CacheManager, UpdatePolicy, make_cache_manager
from cvat_sdk.datasets.common import FrameAnnotations, MediaElement, Sample, UnsupportedDatasetError

_NUM_DOWNLOAD_THREADS = 4
Expand Down Expand Up @@ -49,12 +49,17 @@ def __init__(
task_id: int,
*,
update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE,
load_annotations: bool = True,
) -> None:
"""
Creates a dataset corresponding to the task with ID `task_id` on the
server that `client` is connected to.

`update_policy` determines when and if the local cache will be updated.

`load_annotations` determines whether annotations will be loaded from
the server. If set to False, the `annotations` field in the samples will
be set to None.
"""

self._logger = client.logger
Expand Down Expand Up @@ -102,6 +107,26 @@ def ensure_chunk(chunk_index):

self._logger.info("All chunks downloaded")

if load_annotations:
self._load_annotations(cache_manager, sorted(active_frame_indexes))
else:
self._frame_annotations = {
frame_index: None for frame_index in sorted(active_frame_indexes)
}

# TODO: tracks?

self._samples = [
Sample(
frame_index=k,
frame_name=data_meta.frames[k].name,
annotations=v,
media=self._TaskMediaElement(self, k),
)
for k, v in self._frame_annotations.items()
]

def _load_annotations(self, cache_manager: CacheManager, frame_indexes: Iterable[int]) -> None:
annotations = cache_manager.ensure_task_model(
self._task.id,
"annotations.json",
Expand All @@ -110,9 +135,7 @@ def ensure_chunk(chunk_index):
"annotations",
)

self._frame_annotations = {
frame_index: FrameAnnotations() for frame_index in sorted(active_frame_indexes)
}
self._frame_annotations = {frame_index: FrameAnnotations() for frame_index in frame_indexes}

for tag in annotations.tags:
# Some annotations may belong to deleted frames; skip those.
Expand All @@ -123,18 +146,6 @@ def ensure_chunk(chunk_index):
if shape.frame in self._frame_annotations:
self._frame_annotations[shape.frame].shapes.append(shape)

# TODO: tracks?

self._samples = [
Sample(
frame_index=k,
frame_name=data_meta.frames[k].name,
annotations=v,
media=self._TaskMediaElement(self, k),
)
for k, v in self._frame_annotations.items()
]

@property
def labels(self) -> Sequence[models.ILabel]:
"""
Expand Down
14 changes: 14 additions & 0 deletions tests/python/sdk/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,17 @@ def test_update(self, monkeypatch: pytest.MonkeyPatch):
)

assert dataset.samples[6].annotations.shapes[0].label_id == self.expected_labels[0].id

def test_no_annotations(self):
dataset = cvatds.TaskDataset(self.client, self.task.id, load_annotations=False)

for index, sample in enumerate(dataset.samples):
assert sample.frame_index == index
assert sample.frame_name == self.images[index].name

actual_image = sample.media.load_image()
expected_image = PIL.Image.open(self.images[index])

assert actual_image == expected_image

assert sample.annotations is None