Skip to content

Commit

Permalink
SDK: make the dataset cache directory customizable (#5535)
Browse files Browse the repository at this point in the history
This is useful for people whose home directory is too small/not fast
enough. It also lets us make the tests less hacky.
  • Loading branch information
SpecLad authored Dec 29, 2022
1 parent 72b6125 commit 4d32c3c
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## \[2.4.0] - Unreleased
### Added
- Filename pattern to simplify uploading cloud storage data for a task (<https://github.com/opencv/cvat/pull/5498>)
- \[SDK\] Configuration setting to change the dataset cache directory
(<https://github.com/opencv/cvat/pull/5535>)

### Changed
- The Docker Compose files now use the Compose Specification version
Expand Down
7 changes: 7 additions & 0 deletions cvat-sdk/cvat_sdk/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import logging
import urllib.parse
from contextlib import suppress
from pathlib import Path
from time import sleep
from typing import Any, Dict, Optional, Sequence, Tuple

import appdirs
import attrs
import packaging.version as pv
import urllib3
Expand All @@ -27,6 +29,8 @@
from cvat_sdk.core.proxies.users import UsersRepo
from cvat_sdk.version import VERSION

_DEFAULT_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai"))


@attrs.define
class Config:
Expand All @@ -43,6 +47,9 @@ class Config:
verify_ssl: Optional[bool] = None
"""Whether to verify host SSL certificate or not"""

cache_dir: Path = attrs.field(converter=Path, default=_DEFAULT_CACHE_DIR)
"""Directory in which to store cached server data"""


class Client:
"""
Expand Down
5 changes: 1 addition & 4 deletions cvat-sdk/cvat_sdk/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import types
import zipfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import (
Callable,
Dict,
Expand All @@ -20,7 +19,6 @@
TypeVar,
)

import appdirs
import attrs
import attrs.validators
import PIL.Image
Expand All @@ -36,7 +34,6 @@

_ModelType = TypeVar("_ModelType")

_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai"))
_NUM_DOWNLOAD_THREADS = 4


Expand Down Expand Up @@ -139,7 +136,7 @@ def __init__(
server_dir_name = (
base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode()
)
server_dir = _CACHE_DIR / f"servers/{server_dir_name}"
server_dir = client.config.cache_dir / f"servers/{server_dir_name}"

self._task_dir = server_dir / f"tasks/{self._task.id}"
self._initialize_task_dir()
Expand Down
2 changes: 1 addition & 1 deletion cvat-sdk/gen/templates/openapi-generator/setup.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ setup(
python_requires="{{{generatorLanguageVersion}}}",
install_requires=BASE_REQUIREMENTS,
extras_require={
"pytorch": ['appdirs', 'torch', 'torchvision'],
"pytorch": ['torch', 'torchvision'],
},
package_dir={"": "."},
packages=find_packages(include=["cvat_sdk*"]),
Expand Down
1 change: 1 addition & 0 deletions cvat-sdk/gen/templates/requirements/base.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
-r api_client.txt

appdirs
attrs >= 21.4.0
packaging >= 21.3
Pillow >= 9.0.1
Expand Down
7 changes: 4 additions & 3 deletions tests/python/sdk/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class TestTaskVisionDataset:
@pytest.fixture(autouse=True)
def setup(
self,
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO],
Expand All @@ -41,13 +40,12 @@ def setup(
self.stdout = fxt_stdout
self.client, self.user = fxt_login
self.client.logger = logger
self.client.config.cache_dir = tmp_path / "cache"

api_client = self.client.api_client
for k in api_client.configuration.logger:
api_client.configuration.logger[k] = logger

monkeypatch.setattr(cvatpt, "_CACHE_DIR", self.tmp_path / "cache")

self._create_task()

yield
Expand Down Expand Up @@ -107,6 +105,9 @@ def _create_task(self):
def test_basic(self):
dataset = cvatpt.TaskVisionDataset(self.client, self.task.id)

# verify that the cache is not empty
assert list(self.client.config.cache_dir.iterdir())

assert len(dataset) == self.task.size

for index, (sample_image, sample_target) in enumerate(dataset):
Expand Down

0 comments on commit 4d32c3c

Please sign in to comment.