Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SDK: make the dataset cache directory customizable #5535

Merged
merged 1 commit into from
Dec 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## \[2.4.0] - Unreleased
### Added
- Filename pattern to simplify uploading cloud storage data for a task (<https://github.com/opencv/cvat/pull/5498>)
- \[SDK\] Configuration setting to change the dataset cache directory
(<https://github.com/opencv/cvat/pull/5535>)

### Changed
- The Docker Compose files now use the Compose Specification version
Expand Down
7 changes: 7 additions & 0 deletions cvat-sdk/cvat_sdk/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import logging
import urllib.parse
from contextlib import suppress
from pathlib import Path
from time import sleep
from typing import Any, Dict, Optional, Sequence, Tuple

import appdirs
import attrs
import packaging.version as pv
import urllib3
Expand All @@ -27,6 +29,8 @@
from cvat_sdk.core.proxies.users import UsersRepo
from cvat_sdk.version import VERSION

_DEFAULT_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai"))


@attrs.define
class Config:
Expand All @@ -43,6 +47,9 @@ class Config:
verify_ssl: Optional[bool] = None
"""Whether to verify host SSL certificate or not"""

cache_dir: Path = attrs.field(converter=Path, default=_DEFAULT_CACHE_DIR)
zhiltsov-max marked this conversation as resolved.
Show resolved Hide resolved
"""Directory in which to store cached server data"""


class Client:
"""
Expand Down
5 changes: 1 addition & 4 deletions cvat-sdk/cvat_sdk/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import types
import zipfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import (
Callable,
Dict,
Expand All @@ -20,7 +19,6 @@
TypeVar,
)

import appdirs
import attrs
import attrs.validators
import PIL.Image
Expand All @@ -36,7 +34,6 @@

_ModelType = TypeVar("_ModelType")

_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai"))
_NUM_DOWNLOAD_THREADS = 4


Expand Down Expand Up @@ -139,7 +136,7 @@ def __init__(
server_dir_name = (
base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode()
)
server_dir = _CACHE_DIR / f"servers/{server_dir_name}"
server_dir = client.config.cache_dir / f"servers/{server_dir_name}"

self._task_dir = server_dir / f"tasks/{self._task.id}"
self._initialize_task_dir()
Expand Down
2 changes: 1 addition & 1 deletion cvat-sdk/gen/templates/openapi-generator/setup.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ setup(
python_requires="{{{generatorLanguageVersion}}}",
install_requires=BASE_REQUIREMENTS,
extras_require={
"pytorch": ['appdirs', 'torch', 'torchvision'],
"pytorch": ['torch', 'torchvision'],
},
package_dir={"": "."},
packages=find_packages(include=["cvat_sdk*"]),
Expand Down
1 change: 1 addition & 0 deletions cvat-sdk/gen/templates/requirements/base.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
-r api_client.txt

appdirs
attrs >= 21.4.0
packaging >= 21.3
Pillow >= 9.0.1
Expand Down
7 changes: 4 additions & 3 deletions tests/python/sdk/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class TestTaskVisionDataset:
@pytest.fixture(autouse=True)
def setup(
self,
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO],
Expand All @@ -41,13 +40,12 @@ def setup(
self.stdout = fxt_stdout
self.client, self.user = fxt_login
self.client.logger = logger
self.client.config.cache_dir = tmp_path / "cache"

api_client = self.client.api_client
for k in api_client.configuration.logger:
api_client.configuration.logger[k] = logger

monkeypatch.setattr(cvatpt, "_CACHE_DIR", self.tmp_path / "cache")

self._create_task()

yield
Expand Down Expand Up @@ -107,6 +105,9 @@ def _create_task(self):
def test_basic(self):
dataset = cvatpt.TaskVisionDataset(self.client, self.task.id)

# verify that the cache is not empty
assert list(self.client.config.cache_dir.iterdir())

assert len(dataset) == self.task.size

for index, (sample_image, sample_target) in enumerate(dataset):
Expand Down