diff --git a/CHANGELOG.md b/CHANGELOG.md
index a08bcdf77f02..d5e870c70514 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## \[Unreleased]
 ### Added
 - Multi-line text attributes supported ()
+- \[SDK\] `cvat_sdk.datasets`, a framework-agnostic equivalent of `cvat_sdk.pytorch`
+  ()
 
 ### Changed
 - TDB
diff --git a/cvat-sdk/cvat_sdk/datasets/__init__.py b/cvat-sdk/cvat_sdk/datasets/__init__.py
new file mode 100644
index 000000000000..08dd89165eac
--- /dev/null
+++ b/cvat-sdk/cvat_sdk/datasets/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (C) 2023 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+from .caching import UpdatePolicy
+from .common import FrameAnnotations, MediaElement, Sample, UnsupportedDatasetError
+from .task_dataset import TaskDataset
diff --git a/cvat-sdk/cvat_sdk/pytorch/caching.py b/cvat-sdk/cvat_sdk/datasets/caching.py
similarity index 100%
rename from cvat-sdk/cvat_sdk/pytorch/caching.py
rename to cvat-sdk/cvat_sdk/datasets/caching.py
diff --git a/cvat-sdk/cvat_sdk/datasets/common.py b/cvat-sdk/cvat_sdk/datasets/common.py
new file mode 100644
index 000000000000..2b8269dbd567
--- /dev/null
+++ b/cvat-sdk/cvat_sdk/datasets/common.py
@@ -0,0 +1,57 @@
+# Copyright (C) 2022-2023 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+import abc
+from typing import List
+
+import attrs
+import attrs.validators
+import PIL.Image
+
+import cvat_sdk.core
+import cvat_sdk.core.exceptions
+import cvat_sdk.models as models
+
+
+class UnsupportedDatasetError(cvat_sdk.core.exceptions.CvatSdkException):
+    pass
+
+
+@attrs.frozen
+class FrameAnnotations:
+    """
+    Contains annotations that pertain to a single frame.
+    """
+
+    tags: List[models.LabeledImage] = attrs.Factory(list)
+    shapes: List[models.LabeledShape] = attrs.Factory(list)
+
+
+class MediaElement(metaclass=abc.ABCMeta):
+    """
+    The media part of a dataset sample.
+    """
+
+    @abc.abstractmethod
+    def load_image(self) -> PIL.Image.Image:
+        """
+        Loads the media data and returns it as a PIL Image object.
+        """
+        ...
+
+
+@attrs.frozen
+class Sample:
+    """
+    Represents an element of a dataset.
+    """
+
+    frame_index: int
+    """Index of the corresponding frame in its task."""
+
+    annotations: FrameAnnotations
+    """Annotations belonging to the frame."""
+
+    media: MediaElement
+    """Media data of the frame."""
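The three types above are the entire consumer-facing data model. For orientation, here is a minimal sketch of how calling code might consume a `Sample`; the `sample` argument is hypothetical here, and in practice would come from the `TaskDataset.samples` property introduced below:

```python
from collections import Counter

from cvat_sdk.datasets import Sample


def summarize(sample: Sample) -> Counter:
    # `load_image` is the only way to reach the media; it returns a
    # PIL.Image.Image regardless of how the frame is stored server-side.
    image = sample.media.load_image()
    print(f"frame {sample.frame_index}: size={image.size}, tags={len(sample.annotations.tags)}")

    # Shapes carry label IDs, not names; counting per label ID is typical.
    return Counter(shape.label_id for shape in sample.annotations.shapes)
```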
diff --git a/cvat-sdk/cvat_sdk/datasets/task_dataset.py b/cvat-sdk/cvat_sdk/datasets/task_dataset.py
new file mode 100644
index 000000000000..586070457934
--- /dev/null
+++ b/cvat-sdk/cvat_sdk/datasets/task_dataset.py
@@ -0,0 +1,164 @@
+# Copyright (C) 2022-2023 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+from __future__ import annotations
+
+import zipfile
+from concurrent.futures import ThreadPoolExecutor
+from typing import Sequence
+
+import PIL.Image
+
+import cvat_sdk.core
+import cvat_sdk.core.exceptions
+import cvat_sdk.models as models
+from cvat_sdk.datasets.caching import UpdatePolicy, make_cache_manager
+from cvat_sdk.datasets.common import FrameAnnotations, MediaElement, Sample, UnsupportedDatasetError
+
+_NUM_DOWNLOAD_THREADS = 4
+
+
+class TaskDataset:
+    """
+    Represents a task on a CVAT server as a collection of samples.
+
+    Each sample corresponds to one frame in the task, and provides access to
+    the corresponding annotations and media data. Deleted frames are omitted.
+
+    This class caches all data and annotations for the task on the local file system
+    during construction.
+
+    Limitations:
+
+    * Only tasks with image (not video) data are supported at the moment.
+    * Track annotations are currently not accessible.
+    """
+
+    class _TaskMediaElement(MediaElement):
+        def __init__(self, dataset: TaskDataset, frame_index: int) -> None:
+            self._dataset = dataset
+            self._frame_index = frame_index
+
+        def load_image(self) -> PIL.Image.Image:
+            return self._dataset._load_frame_image(self._frame_index)
+
+    def __init__(
+        self,
+        client: cvat_sdk.core.Client,
+        task_id: int,
+        *,
+        update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE,
+    ) -> None:
+        """
+        Creates a dataset corresponding to the task with ID `task_id` on the
+        server that `client` is connected to.
+
+        `update_policy` determines when and if the local cache will be updated.
+        """
+
+        self._logger = client.logger
+
+        cache_manager = make_cache_manager(client, update_policy)
+        self._task = cache_manager.retrieve_task(task_id)
+
+        if not self._task.size or not self._task.data_chunk_size:
+            raise UnsupportedDatasetError("The task has no data")
+
+        if self._task.data_original_chunk_type != "imageset":
+            raise UnsupportedDatasetError(
+                f"{self.__class__.__name__} only supports tasks with image chunks;"
+                f" current chunk type is {self._task.data_original_chunk_type!r}"
+            )
+
+        self._logger.info("Fetching labels...")
+        self._labels = tuple(self._task.get_labels())
+
+        data_meta = cache_manager.ensure_task_model(
+            self._task.id,
+            "data_meta.json",
+            models.DataMetaRead,
+            self._task.get_meta,
+            "data metadata",
+        )
+
+        active_frame_indexes = set(range(self._task.size)) - set(data_meta.deleted_frames)
+
+        self._logger.info("Downloading chunks...")
+
+        self._chunk_dir = cache_manager.chunk_dir(task_id)
+        self._chunk_dir.mkdir(exist_ok=True, parents=True)
+
+        needed_chunks = {index // self._task.data_chunk_size for index in active_frame_indexes}
+
+        with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool:
+
+            def ensure_chunk(chunk_index):
+                cache_manager.ensure_chunk(self._task, chunk_index)
+
+            for _ in pool.map(ensure_chunk, sorted(needed_chunks)):
+                # just need to loop through all results so that any exceptions are propagated
+                pass
+
+        self._logger.info("All chunks downloaded")
+
+        annotations = cache_manager.ensure_task_model(
+            self._task.id,
+            "annotations.json",
+            models.LabeledData,
+            self._task.get_annotations,
+            "annotations",
+        )
+
+        self._frame_annotations = {
+            frame_index: FrameAnnotations() for frame_index in sorted(active_frame_indexes)
+        }
+
+        for tag in annotations.tags:
+            # Some annotations may belong to deleted frames; skip those.
+            if tag.frame in self._frame_annotations:
+                self._frame_annotations[tag.frame].tags.append(tag)
+
+        for shape in annotations.shapes:
+            if shape.frame in self._frame_annotations:
+                self._frame_annotations[shape.frame].shapes.append(shape)
+
+        # TODO: tracks?
+
+        self._samples = [
+            Sample(frame_index=k, annotations=v, media=self._TaskMediaElement(self, k))
+            for k, v in self._frame_annotations.items()
+        ]
+
+    @property
+    def labels(self) -> Sequence[models.ILabel]:
+        """
+        Returns the labels configured in the task.
+
+        Clients must not modify the object returned by this property or its components.
+        """
+        return self._labels
+
+    @property
+    def samples(self) -> Sequence[Sample]:
+        """
+        Returns a sequence of all samples, in order of their frame indices.
+
+        Note that the frame indices may not be contiguous, as deleted frames will not be included.
+
+        Clients must not modify the object returned by this property or its components.
+        """
+        return self._samples
+
+    def _load_frame_image(self, frame_index: int) -> PIL.Image.Image:
+        assert frame_index in self._frame_annotations
+
+        chunk_index = frame_index // self._task.data_chunk_size
+        member_index = frame_index % self._task.data_chunk_size
+
+        with zipfile.ZipFile(self._chunk_dir / f"{chunk_index}.zip", "r") as chunk_zip:
+            with chunk_zip.open(chunk_zip.infolist()[member_index]) as chunk_member:
+                image = PIL.Image.open(chunk_member)
+                image.load()
+
+        return image
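Putting the pieces together, an end-to-end use of the new layer might look like the sketch below; the host, credentials, and task ID are placeholders, not part of this change:

```python
from cvat_sdk import make_client
from cvat_sdk.datasets import TaskDataset, UpdatePolicy

with make_client("http://localhost:8080", credentials=("user", "password")) as client:
    # The first construction downloads chunks and annotations into the cache.
    dataset = TaskDataset(client, task_id=42)

    for sample in dataset.samples:
        image = sample.media.load_image()
        print(sample.frame_index, image.size, len(sample.annotations.shapes))

    # Subsequent constructions can run entirely from the cache.
    offline = TaskDataset(client, task_id=42, update_policy=UpdatePolicy.NEVER)
```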
diff --git a/cvat-sdk/cvat_sdk/pytorch/__init__.py b/cvat-sdk/cvat_sdk/pytorch/__init__.py
index ba6609b268a4..3fa537ff99c0 100644
--- a/cvat-sdk/cvat_sdk/pytorch/__init__.py
+++ b/cvat-sdk/cvat_sdk/pytorch/__init__.py
@@ -2,8 +2,12 @@
 #
 # SPDX-License-Identifier: MIT
 
-from .caching import UpdatePolicy
-from .common import FrameAnnotations, Target, UnsupportedDatasetError
+from .common import Target
 from .project_dataset import ProjectVisionDataset
 from .task_dataset import TaskVisionDataset
 from .transforms import ExtractBoundingBoxes, ExtractSingleLabelIndex, LabeledBoxes
+
+# isort: split
+# Compatibility imports
+from ..datasets.caching import UpdatePolicy
+from ..datasets.common import FrameAnnotations, UnsupportedDatasetError
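Thanks to these compatibility re-exports, code that imported the relocated names from `cvat_sdk.pytorch` keeps working unchanged. A quick sanity check, assuming the SDK is installed:

```python
import cvat_sdk.datasets
import cvat_sdk.pytorch

# The old import paths now resolve to the relocated classes.
assert cvat_sdk.pytorch.UpdatePolicy is cvat_sdk.datasets.UpdatePolicy
assert cvat_sdk.pytorch.FrameAnnotations is cvat_sdk.datasets.FrameAnnotations
assert cvat_sdk.pytorch.UnsupportedDatasetError is cvat_sdk.datasets.UnsupportedDatasetError
```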
diff --git a/cvat-sdk/cvat_sdk/pytorch/common.py b/cvat-sdk/cvat_sdk/pytorch/common.py
index ac5d8fb7ad96..97ef38bc33a8 100644
--- a/cvat-sdk/cvat_sdk/pytorch/common.py
+++ b/cvat-sdk/cvat_sdk/pytorch/common.py
@@ -2,28 +2,11 @@
 #
 # SPDX-License-Identifier: MIT
 
-from typing import List, Mapping
+from typing import Mapping
 
 import attrs
-import attrs.validators
 
-import cvat_sdk.core
-import cvat_sdk.core.exceptions
-import cvat_sdk.models as models
-
-
-class UnsupportedDatasetError(cvat_sdk.core.exceptions.CvatSdkException):
-    pass
-
-
-@attrs.frozen
-class FrameAnnotations:
-    """
-    Contains annotations that pertain to a single frame.
-    """
-
-    tags: List[models.LabeledImage] = attrs.Factory(list)
-    shapes: List[models.LabeledShape] = attrs.Factory(list)
+from cvat_sdk.datasets.common import FrameAnnotations
 
 
 @attrs.frozen
diff --git a/cvat-sdk/cvat_sdk/pytorch/project_dataset.py b/cvat-sdk/cvat_sdk/pytorch/project_dataset.py
index be834b1cedd9..ada554ee1210 100644
--- a/cvat-sdk/cvat_sdk/pytorch/project_dataset.py
+++ b/cvat-sdk/cvat_sdk/pytorch/project_dataset.py
@@ -12,7 +12,7 @@
 import cvat_sdk.core
 import cvat_sdk.core.exceptions
 import cvat_sdk.models as models
-from cvat_sdk.pytorch.caching import UpdatePolicy, make_cache_manager
+from cvat_sdk.datasets.caching import UpdatePolicy, make_cache_manager
 from cvat_sdk.pytorch.task_dataset import TaskVisionDataset
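The project-level adapter only swaps its import of the caching machinery; its public interface is untouched. For reference, a usage sketch (the project ID and subset name are placeholders, and `client` is assumed to be an existing `cvat_sdk.core.Client`):

```python
import cvat_sdk.pytorch as cvatpt

# Concatenates a TaskVisionDataset per task in the project.
dataset = cvatpt.ProjectVisionDataset(
    client,
    project_id=5,
    include_subsets=["Train"],  # existing parameter for filtering tasks by subset
    update_policy=cvatpt.UpdatePolicy.IF_MISSING_OR_STALE,
)
```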
""" - self._logger = client.logger + self._underlying = TaskDataset(client, task_id, update_policy=update_policy) cache_manager = make_cache_manager(client, update_policy) - self._task = cache_manager.retrieve_task(task_id) - - if not self._task.size or not self._task.data_chunk_size: - raise UnsupportedDatasetError("The task has no data") - - if self._task.data_original_chunk_type != "imageset": - raise UnsupportedDatasetError( - f"{self.__class__.__name__} only supports tasks with image chunks;" - f" current chunk type is {self._task.data_original_chunk_type!r}" - ) super().__init__( - os.fspath(cache_manager.task_dir(self._task.id)), + os.fspath(cache_manager.task_dir(task_id)), transforms=transforms, transform=transform, target_transform=target_transform, ) - data_meta = cache_manager.ensure_task_model( - self._task.id, - "data_meta.json", - models.DataMetaRead, - self._task.get_meta, - "data metadata", - ) - self._active_frame_indexes = sorted( - set(range(self._task.size)) - set(data_meta.deleted_frames) - ) - - self._logger.info("Downloading chunks...") - - self._chunk_dir = cache_manager.chunk_dir(task_id) - self._chunk_dir.mkdir(exist_ok=True, parents=True) - - needed_chunks = { - index // self._task.data_chunk_size for index in self._active_frame_indexes - } - - with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool: - - def ensure_chunk(chunk_index): - cache_manager.ensure_chunk(self._task, chunk_index) - - for _ in pool.map(ensure_chunk, sorted(needed_chunks)): - # just need to loop through all results so that any exceptions are propagated - pass - - self._logger.info("All chunks downloaded") - if label_name_to_index is None: self._label_id_to_index = types.MappingProxyType( { label.id: label_index for label_index, label in enumerate( - sorted(self._task.get_labels(), key=lambda l: l.id) + sorted(self._underlying.labels, key=lambda l: l.id) ) } ) else: self._label_id_to_index = types.MappingProxyType( - {label.id: label_name_to_index[label.name] for label in self._task.get_labels()} + {label.id: label_name_to_index[label.name] for label in self._underlying.labels} ) - annotations = cache_manager.ensure_task_model( - self._task.id, - "annotations.json", - models.LabeledData, - self._task.get_annotations, - "annotations", - ) - - self._frame_annotations: Dict[int, FrameAnnotations] = collections.defaultdict( - FrameAnnotations - ) - - for tag in annotations.tags: - self._frame_annotations[tag.frame].tags.append(tag) - - for shape in annotations.shapes: - self._frame_annotations[shape.frame].shapes.append(shape) - - # TODO: tracks? - def __getitem__(self, sample_index: int): """ Returns the sample with index `sample_index`. @@ -168,19 +103,10 @@ def __getitem__(self, sample_index: int): `sample_index` must satisfy the condition `0 <= sample_index < len(self)`. 
""" - frame_index = self._active_frame_indexes[sample_index] - chunk_index = frame_index // self._task.data_chunk_size - member_index = frame_index % self._task.data_chunk_size + sample = self._underlying.samples[sample_index] - with zipfile.ZipFile(self._chunk_dir / f"{chunk_index}.zip", "r") as chunk_zip: - with chunk_zip.open(chunk_zip.infolist()[member_index]) as chunk_member: - sample_image = PIL.Image.open(chunk_member) - sample_image.load() - - sample_target = Target( - annotations=self._frame_annotations[frame_index], - label_id_to_index=self._label_id_to_index, - ) + sample_image = sample.media.load_image() + sample_target = Target(sample.annotations, self._label_id_to_index) if self.transforms: sample_image, sample_target = self.transforms(sample_image, sample_target) @@ -188,4 +114,4 @@ def __getitem__(self, sample_index: int): def __len__(self) -> int: """Returns the number of samples in the dataset.""" - return len(self._active_frame_indexes) + return len(self._underlying.samples) diff --git a/cvat-sdk/cvat_sdk/pytorch/transforms.py b/cvat-sdk/cvat_sdk/pytorch/transforms.py index 259ebc045375..d63fdba65f68 100644 --- a/cvat-sdk/cvat_sdk/pytorch/transforms.py +++ b/cvat-sdk/cvat_sdk/pytorch/transforms.py @@ -10,7 +10,8 @@ import torch.utils.data from typing_extensions import TypedDict -from cvat_sdk.pytorch.common import Target, UnsupportedDatasetError +from cvat_sdk.datasets.common import UnsupportedDatasetError +from cvat_sdk.pytorch.common import Target @attrs.frozen diff --git a/tests/python/sdk/test_datasets.py b/tests/python/sdk/test_datasets.py new file mode 100644 index 000000000000..67204e4c26c9 --- /dev/null +++ b/tests/python/sdk/test_datasets.py @@ -0,0 +1,207 @@ +# Copyright (C) 2023 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import io +from logging import Logger +from pathlib import Path +from typing import Tuple + +import cvat_sdk.datasets as cvatds +import PIL.Image +import pytest +from cvat_sdk import Client, models +from cvat_sdk.core.proxies.tasks import ResourceType + +from shared.utils.helpers import generate_image_files + +from .util import restrict_api_requests + + +@pytest.fixture(autouse=True) +def _common_setup( + tmp_path: Path, + fxt_login: Tuple[Client, str], + fxt_logger: Tuple[Logger, io.StringIO], +): + logger = fxt_logger[0] + client = fxt_login[0] + client.logger = logger + client.config.cache_dir = tmp_path / "cache" + + api_client = client.api_client + for k in api_client.configuration.logger: + api_client.configuration.logger[k] = logger + + +class TestTaskDataset: + @pytest.fixture(autouse=True) + def setup( + self, + tmp_path: Path, + fxt_login: Tuple[Client, str], + ): + self.client = fxt_login[0] + self.images = generate_image_files(10) + + image_dir = tmp_path / "images" + image_dir.mkdir() + + image_paths = [] + for image in self.images: + image_path = image_dir / image.name + image_path.write_bytes(image.getbuffer()) + image_paths.append(image_path) + + self.task = self.client.tasks.create_from_data( + models.TaskWriteRequest( + "Dataset layer test task", + labels=[ + models.PatchedLabelRequest(name="person"), + models.PatchedLabelRequest(name="car"), + ], + ), + resource_type=ResourceType.LOCAL, + resources=image_paths, + data_params={"chunk_size": 3}, + ) + + self.expected_labels = sorted(self.task.get_labels(), key=lambda l: l.id) + + self.task.update_annotations( + models.PatchedLabeledDataRequest( + tags=[ + models.LabeledImageRequest(frame=8, label_id=self.expected_labels[0].id), + 
diff --git a/tests/python/sdk/test_datasets.py b/tests/python/sdk/test_datasets.py
new file mode 100644
index 000000000000..67204e4c26c9
--- /dev/null
+++ b/tests/python/sdk/test_datasets.py
@@ -0,0 +1,207 @@
+# Copyright (C) 2023 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+import io
+from logging import Logger
+from pathlib import Path
+from typing import Tuple
+
+import cvat_sdk.datasets as cvatds
+import PIL.Image
+import pytest
+from cvat_sdk import Client, models
+from cvat_sdk.core.proxies.tasks import ResourceType
+
+from shared.utils.helpers import generate_image_files
+
+from .util import restrict_api_requests
+
+
+@pytest.fixture(autouse=True)
+def _common_setup(
+    tmp_path: Path,
+    fxt_login: Tuple[Client, str],
+    fxt_logger: Tuple[Logger, io.StringIO],
+):
+    logger = fxt_logger[0]
+    client = fxt_login[0]
+    client.logger = logger
+    client.config.cache_dir = tmp_path / "cache"
+
+    api_client = client.api_client
+    for k in api_client.configuration.logger:
+        api_client.configuration.logger[k] = logger
+
+
+class TestTaskDataset:
+    @pytest.fixture(autouse=True)
+    def setup(
+        self,
+        tmp_path: Path,
+        fxt_login: Tuple[Client, str],
+    ):
+        self.client = fxt_login[0]
+        self.images = generate_image_files(10)
+
+        image_dir = tmp_path / "images"
+        image_dir.mkdir()
+
+        image_paths = []
+        for image in self.images:
+            image_path = image_dir / image.name
+            image_path.write_bytes(image.getbuffer())
+            image_paths.append(image_path)
+
+        self.task = self.client.tasks.create_from_data(
+            models.TaskWriteRequest(
+                "Dataset layer test task",
+                labels=[
+                    models.PatchedLabelRequest(name="person"),
+                    models.PatchedLabelRequest(name="car"),
+                ],
+            ),
+            resource_type=ResourceType.LOCAL,
+            resources=image_paths,
+            data_params={"chunk_size": 3},
+        )
+
+        self.expected_labels = sorted(self.task.get_labels(), key=lambda l: l.id)
+
+        self.task.update_annotations(
+            models.PatchedLabeledDataRequest(
+                tags=[
+                    models.LabeledImageRequest(frame=8, label_id=self.expected_labels[0].id),
+                    models.LabeledImageRequest(frame=8, label_id=self.expected_labels[1].id),
+                ],
+                shapes=[
+                    models.LabeledShapeRequest(
+                        frame=6,
+                        label_id=self.expected_labels[1].id,
+                        type=models.ShapeType("rectangle"),
+                        points=[1.0, 2.0, 3.0, 4.0],
+                    ),
+                ],
+            )
+        )
+
+    def test_basic(self):
+        dataset = cvatds.TaskDataset(self.client, self.task.id)
+
+        # verify that the cache is not empty
+        assert list(self.client.config.cache_dir.iterdir())
+
+        for expected_label, actual_label in zip(
+            self.expected_labels, sorted(dataset.labels, key=lambda l: l.id)
+        ):
+            assert expected_label.id == actual_label.id
+            assert expected_label.name == actual_label.name
+
+        assert len(dataset.samples) == self.task.size
+
+        for index, sample in enumerate(dataset.samples):
+            assert sample.frame_index == index
+
+            actual_image = sample.media.load_image()
+            expected_image = PIL.Image.open(self.images[index])
+
+            assert actual_image == expected_image
+
+        assert not dataset.samples[0].annotations.tags
+        assert not dataset.samples[1].annotations.shapes
+
+        assert {tag.label_id for tag in dataset.samples[8].annotations.tags} == {
+            label.id for label in self.expected_labels
+        }
+        assert not dataset.samples[8].annotations.shapes
+
+        assert not dataset.samples[6].annotations.tags
+        assert len(dataset.samples[6].annotations.shapes) == 1
+        assert dataset.samples[6].annotations.shapes[0].type.value == "rectangle"
+        assert dataset.samples[6].annotations.shapes[0].points == [1.0, 2.0, 3.0, 4.0]
+
+    def test_deleted_frame(self):
+        self.task.remove_frames_by_ids([1])
+
+        dataset = cvatds.TaskDataset(self.client, self.task.id)
+
+        assert len(dataset.samples) == self.task.size - 1
+
+        # sample #0 is still frame #0
+        assert dataset.samples[0].frame_index == 0
+        assert dataset.samples[0].media.load_image() == PIL.Image.open(self.images[0])
+
+        # sample #1 is now frame #2
+        assert dataset.samples[1].frame_index == 2
+        assert dataset.samples[1].media.load_image() == PIL.Image.open(self.images[2])
+
+        # sample #5 is now frame #6
+        assert dataset.samples[5].frame_index == 6
+        assert dataset.samples[5].media.load_image() == PIL.Image.open(self.images[6])
+        assert len(dataset.samples[5].annotations.shapes) == 1
+
+    def test_offline(self, monkeypatch: pytest.MonkeyPatch):
+        dataset = cvatds.TaskDataset(
+            self.client,
+            self.task.id,
+            update_policy=cvatds.UpdatePolicy.IF_MISSING_OR_STALE,
+        )
+
+        fresh_samples = list(dataset.samples)
+
+        restrict_api_requests(monkeypatch)
+
+        dataset = cvatds.TaskDataset(
+            self.client,
+            self.task.id,
+            update_policy=cvatds.UpdatePolicy.NEVER,
+        )
+
+        cached_samples = list(dataset.samples)
+
+        for fresh_sample, cached_sample in zip(fresh_samples, cached_samples):
+            assert fresh_sample.frame_index == cached_sample.frame_index
+            assert fresh_sample.annotations == cached_sample.annotations
+            assert fresh_sample.media.load_image() == cached_sample.media.load_image()
+
+    def test_update(self, monkeypatch: pytest.MonkeyPatch):
+        dataset = cvatds.TaskDataset(
+            self.client,
+            self.task.id,
+        )
+
+        # Recreating the dataset should only result in minimal requests.
+        restrict_api_requests(
+            monkeypatch, allow_paths={f"/api/tasks/{self.task.id}", "/api/labels"}
+        )
+
+        dataset = cvatds.TaskDataset(
+            self.client,
+            self.task.id,
+        )
+
+        assert dataset.samples[6].annotations.shapes[0].label_id == self.expected_labels[1].id
+
+        # After an update, the annotations should be redownloaded.
+        monkeypatch.undo()
+
+        self.task.update_annotations(
+            models.PatchedLabeledDataRequest(
+                shapes=[
+                    models.LabeledShapeRequest(
+                        id=dataset.samples[6].annotations.shapes[0].id,
+                        frame=6,
+                        label_id=self.expected_labels[0].id,
+                        type=models.ShapeType("rectangle"),
+                        points=[1.0, 2.0, 3.0, 4.0],
+                    ),
+                ]
+            )
+        )
+
+        dataset = cvatds.TaskDataset(
+            self.client,
+            self.task.id,
+        )
+
+        assert dataset.samples[6].annotations.shapes[0].label_id == self.expected_labels[0].id
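As a worked check of the chunk arithmetic these tests exercise (the fixture uses `data_params={"chunk_size": 3}` with 10 images), frame indices map to chunk members like so:

```python
# Mirrors the mapping in TaskDataset._load_frame_image; not part of the change.
chunk_size = 3

for frame_index in (0, 2, 6, 8, 9):
    chunk_index, member_index = divmod(frame_index, chunk_size)
    print(f"frame {frame_index} -> {chunk_index}.zip, member {member_index}")

# frame 6 (the one with the rectangle) is member 0 of 2.zip
```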
diff --git a/tests/python/sdk/test_pytorch.py b/tests/python/sdk/test_pytorch.py
index bcfedf7b8b3b..722cb37ab003 100644
--- a/tests/python/sdk/test_pytorch.py
+++ b/tests/python/sdk/test_pytorch.py
@@ -7,12 +7,10 @@
 import os
 from logging import Logger
 from pathlib import Path
-from typing import Container, Tuple
-from urllib.parse import urlparse
+from typing import Tuple
 
 import pytest
 from cvat_sdk import Client, models
-from cvat_sdk.api_client.rest import RESTClientObject
 from cvat_sdk.core.proxies.tasks import ResourceType
 
 try:
@@ -30,6 +28,8 @@
 
 from shared.utils.helpers import generate_image_files
 
+from .util import restrict_api_requests
+
 
 @pytest.fixture(autouse=True)
 def _common_setup(
@@ -47,20 +47,6 @@
         api_client.configuration.logger[k] = logger
 
 
-def _restrict_api_requests(
-    monkeypatch: pytest.MonkeyPatch, allow_paths: Container[str] = ()
-) -> None:
-    original_request = RESTClientObject.request
-
-    def restricted_request(self, method, url, *args, **kwargs):
-        parsed_url = urlparse(url)
-        if parsed_url.path in allow_paths:
-            return original_request(self, method, url, *args, **kwargs)
-        raise RuntimeError("Disallowed!")
-
-    monkeypatch.setattr(RESTClientObject, "request", restricted_request)
-
-
 @pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
 class TestTaskVisionDataset:
     @pytest.fixture(autouse=True)
@@ -254,7 +240,7 @@ def test_offline(self, monkeypatch: pytest.MonkeyPatch):
 
         fresh_samples = list(dataset)
 
-        _restrict_api_requests(monkeypatch)
+        restrict_api_requests(monkeypatch)
 
         dataset = cvatpt.TaskVisionDataset(
             self.client,
@@ -273,7 +259,7 @@ def test_update(self, monkeypatch: pytest.MonkeyPatch):
         )
 
         # Recreating the dataset should only result in minimal requests.
- _restrict_api_requests( + restrict_api_requests( monkeypatch, allow_paths={f"/api/tasks/{self.task.id}", "/api/labels"} ) @@ -447,7 +433,7 @@ def test_offline(self, monkeypatch: pytest.MonkeyPatch): fresh_samples = list(dataset) - _restrict_api_requests(monkeypatch) + restrict_api_requests(monkeypatch) dataset = cvatpt.ProjectVisionDataset( self.client, diff --git a/tests/python/sdk/util.py b/tests/python/sdk/util.py index 83e6b10e2908..5861c658111a 100644 --- a/tests/python/sdk/util.py +++ b/tests/python/sdk/util.py @@ -4,8 +4,11 @@ import textwrap from pathlib import Path -from typing import Tuple +from typing import Container, Tuple +from urllib.parse import urlparse +import pytest +from cvat_sdk.api_client.rest import RESTClientObject from cvat_sdk.core.helpers import TqdmProgressReporter from tqdm import tqdm @@ -82,3 +85,17 @@ def generate_coco_anno(image_path: str, image_width: int, image_height: int) -> "image_width": image_width, } ) + + +def restrict_api_requests( + monkeypatch: pytest.MonkeyPatch, allow_paths: Container[str] = () +) -> None: + original_request = RESTClientObject.request + + def restricted_request(self, method, url, *args, **kwargs): + parsed_url = urlparse(url) + if parsed_url.path in allow_paths: + return original_request(self, method, url, *args, **kwargs) + raise RuntimeError("Disallowed!") + + monkeypatch.setattr(RESTClientObject, "request", restricted_request)
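With the helper promoted to `util.py`, any SDK test module can assert offline behavior the same way. A sketch of the pattern, assuming a test class with the same fixtures as `TestTaskDataset` above:

```python
def test_runs_from_cache_only(self, monkeypatch: pytest.MonkeyPatch):
    # Populate the cache with a fully-online construction first.
    cvatds.TaskDataset(self.client, self.task.id)

    # Block every request; the allow list is empty by default.
    restrict_api_requests(monkeypatch)

    # This must now succeed using only cached data; any HTTP call
    # would raise RuntimeError("Disallowed!").
    cvatds.TaskDataset(
        self.client, self.task.id, update_policy=cvatds.UpdatePolicy.NEVER
    )
```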