Skip to content

Commit

Permalink
PyTorch adapter: fix loading tasks that have already been cached (#6396)
Browse files Browse the repository at this point in the history
<!-- Raise an issue to propose your change
(https://github.com/opencv/cvat/issues).
It helps to avoid duplication of efforts from multiple independent
contributors.
Discuss your ideas with maintainers to be sure that changes will be
approved and merged.
Read the [Contribution
guide](https://opencv.github.io/cvat/docs/contributing/). -->

<!-- Provide a general summary of your changes in the Title above -->

### Motivation and context
<!-- Why is this change required? What problem does it solve? If it
fixes an open
issue, please link to the issue here. Describe your changes in detail,
add
screenshots. -->
This was broken in 4fc494f.

Add a test to cover this case.

Fixes #6047.
  • Loading branch information
SpecLad authored Jun 28, 2023
1 parent 1b28162 commit d950d24
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added missed auto_add argument to Issue model (<https://github.com/opencv/cvat/pull/6364>)
- \[API\] Performance of several API endpoints (<https://github.com/opencv/cvat/pull/6340>)
- \[API\] Invalid schema for the owner field in several endpoints (<https://github.com/opencv/cvat/pull/6343>)
- \[SDK\] Loading tasks that have been cached with the PyTorch adapter
(<https://github.com/opencv/cvat/issues/6047>)

### Security
- TDB
Expand Down
2 changes: 1 addition & 1 deletion cvat-sdk/cvat_sdk/pytorch/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _initialize_task_dir(self, task: Task) -> None:
if task_dir.exists():
shutil.rmtree(task_dir)
else:
if saved_task.updated_date < task.updated_date:
if saved_task.api_model.updated_date < task.updated_date:
self._logger.info(
f"Task {task.id} has been updated on the server since it was cached; purging the cache"
)
Expand Down
60 changes: 53 additions & 7 deletions tests/python/sdk/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import os
from logging import Logger
from pathlib import Path
from typing import Tuple
from typing import Container, Tuple
from urllib.parse import urlparse

import pytest
from cvat_sdk import Client, models
Expand Down Expand Up @@ -46,11 +47,18 @@ def _common_setup(
api_client.configuration.logger[k] = logger


def _disable_api_requests(monkeypatch: pytest.MonkeyPatch) -> None:
def disabled_request(*args, **kwargs):
raise RuntimeError("Disabled!")
def _restrict_api_requests(
monkeypatch: pytest.MonkeyPatch, allow_paths: Container[str] = ()
) -> None:
original_request = RESTClientObject.request

monkeypatch.setattr(RESTClientObject, "request", disabled_request)
def restricted_request(self, method, url, *args, **kwargs):
parsed_url = urlparse(url)
if parsed_url.path in allow_paths:
return original_request(self, method, url, *args, **kwargs)
raise RuntimeError("Disallowed!")

monkeypatch.setattr(RESTClientObject, "request", restricted_request)


@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
Expand Down Expand Up @@ -246,7 +254,7 @@ def test_offline(self, monkeypatch: pytest.MonkeyPatch):

fresh_samples = list(dataset)

_disable_api_requests(monkeypatch)
_restrict_api_requests(monkeypatch)

dataset = cvatpt.TaskVisionDataset(
self.client,
Expand All @@ -258,6 +266,44 @@ def test_offline(self, monkeypatch: pytest.MonkeyPatch):

assert fresh_samples == cached_samples

def test_update(self, monkeypatch: pytest.MonkeyPatch):
dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
)

# Recreating the dataset should only result in minimal requests.
_restrict_api_requests(
monkeypatch, allow_paths={f"/api/tasks/{self.task.id}", "/api/labels"}
)

dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
)

assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[0]

# After an update, the annotations should be redownloaded.
monkeypatch.undo()

self.task.update_annotations(
models.PatchedLabeledDataRequest(
tags=[
models.LabeledImageRequest(
id=dataset[5][1].annotations.tags[0].id, frame=5, label_id=self.label_ids[1]
),
]
)
)

dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
)

assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[1]


@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
class TestProjectVisionDataset:
Expand Down Expand Up @@ -401,7 +447,7 @@ def test_offline(self, monkeypatch: pytest.MonkeyPatch):

fresh_samples = list(dataset)

_disable_api_requests(monkeypatch)
_restrict_api_requests(monkeypatch)

dataset = cvatpt.ProjectVisionDataset(
self.client,
Expand Down

0 comments on commit d950d24

Please sign in to comment.