From 68354f6dfdc5ecf2eecf65633c6b23a385a2d424 Mon Sep 17 00:00:00 2001 From: Roman Donchenko Date: Thu, 29 Dec 2022 17:41:00 +0300 Subject: [PATCH] SDK: make the dataset cache directory customizable --- CHANGELOG.md | 2 ++ cvat-sdk/cvat_sdk/core/client.py | 7 +++++++ cvat-sdk/cvat_sdk/pytorch/__init__.py | 5 +---- cvat-sdk/gen/templates/openapi-generator/setup.mustache | 2 +- cvat-sdk/gen/templates/requirements/base.txt | 1 + tests/python/sdk/test_pytorch.py | 7 ++++--- 6 files changed, 16 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b5d75a2f457..2e517700ad42 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## \[2.4.0] - Unreleased ### Added - Filename pattern to simplify uploading cloud storage data for a task () +- \[SDK\] Configuration setting to change the dataset cache directory + () ### Changed - The Docker Compose files now use the Compose Specification version diff --git a/cvat-sdk/cvat_sdk/core/client.py b/cvat-sdk/cvat_sdk/core/client.py index a68fb9840773..04169ced5d21 100644 --- a/cvat-sdk/cvat_sdk/core/client.py +++ b/cvat-sdk/cvat_sdk/core/client.py @@ -8,9 +8,11 @@ import logging import urllib.parse from contextlib import suppress +from pathlib import Path from time import sleep from typing import Any, Dict, Optional, Sequence, Tuple +import appdirs import attrs import packaging.version as pv import urllib3 @@ -27,6 +29,8 @@ from cvat_sdk.core.proxies.users import UsersRepo from cvat_sdk.version import VERSION +_DEFAULT_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai")) + @attrs.define class Config: @@ -43,6 +47,9 @@ class Config: verify_ssl: Optional[bool] = None """Whether to verify host SSL certificate or not""" + cache_dir: Path = attrs.field(converter=Path, default=_DEFAULT_CACHE_DIR) + """Directory in which to store cached server data""" + class Client: """ diff --git a/cvat-sdk/cvat_sdk/pytorch/__init__.py b/cvat-sdk/cvat_sdk/pytorch/__init__.py index fa6b38a00623..9bd24201c8b7 100644 --- a/cvat-sdk/cvat_sdk/pytorch/__init__.py +++ b/cvat-sdk/cvat_sdk/pytorch/__init__.py @@ -6,7 +6,6 @@ import types import zipfile from concurrent.futures import ThreadPoolExecutor -from pathlib import Path from typing import ( Callable, Dict, @@ -20,7 +19,6 @@ TypeVar, ) -import appdirs import attrs import attrs.validators import PIL.Image @@ -36,7 +34,6 @@ _ModelType = TypeVar("_ModelType") -_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai")) _NUM_DOWNLOAD_THREADS = 4 @@ -139,7 +136,7 @@ def __init__( server_dir_name = ( base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode() ) - server_dir = _CACHE_DIR / f"servers/{server_dir_name}" + server_dir = client.config.cache_dir / f"servers/{server_dir_name}" self._task_dir = server_dir / f"tasks/{self._task.id}" self._initialize_task_dir() diff --git a/cvat-sdk/gen/templates/openapi-generator/setup.mustache b/cvat-sdk/gen/templates/openapi-generator/setup.mustache index 13c2a9535966..eb89f5d20554 100644 --- a/cvat-sdk/gen/templates/openapi-generator/setup.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/setup.mustache @@ -77,7 +77,7 @@ setup( python_requires="{{{generatorLanguageVersion}}}", install_requires=BASE_REQUIREMENTS, extras_require={ - "pytorch": ['appdirs', 'torch', 'torchvision'], + "pytorch": ['torch', 'torchvision'], }, package_dir={"": "."}, packages=find_packages(include=["cvat_sdk*"]), diff --git a/cvat-sdk/gen/templates/requirements/base.txt b/cvat-sdk/gen/templates/requirements/base.txt index ffc88d7e7eff..bfb1e723a6d5 100644 --- a/cvat-sdk/gen/templates/requirements/base.txt +++ b/cvat-sdk/gen/templates/requirements/base.txt @@ -1,5 +1,6 @@ -r api_client.txt +appdirs attrs >= 21.4.0 packaging >= 21.3 Pillow >= 9.0.1 diff --git a/tests/python/sdk/test_pytorch.py b/tests/python/sdk/test_pytorch.py index 1aa61174cd5f..77cd6ecd22f2 100644 --- a/tests/python/sdk/test_pytorch.py +++ b/tests/python/sdk/test_pytorch.py @@ -30,7 +30,6 @@ class TestTaskVisionDataset: @pytest.fixture(autouse=True) def setup( self, - monkeypatch: pytest.MonkeyPatch, tmp_path: Path, fxt_login: Tuple[Client, str], fxt_logger: Tuple[Logger, io.StringIO], @@ -41,13 +40,12 @@ def setup( self.stdout = fxt_stdout self.client, self.user = fxt_login self.client.logger = logger + self.client.config.cache_dir = tmp_path / "cache" api_client = self.client.api_client for k in api_client.configuration.logger: api_client.configuration.logger[k] = logger - monkeypatch.setattr(cvatpt, "_CACHE_DIR", self.tmp_path / "cache") - self._create_task() yield @@ -107,6 +105,9 @@ def _create_task(self): def test_basic(self): dataset = cvatpt.TaskVisionDataset(self.client, self.task.id) + # verify that the cache is not empty + assert list(self.client.config.cache_dir.iterdir()) + assert len(dataset) == self.task.size for index, (sample_image, sample_target) in enumerate(dataset):