diff --git a/cvat-cli/src/cvat_cli/__main__.py b/cvat-cli/src/cvat_cli/__main__.py index 63cb31a0b3a0..ce32b832fdde 100755 --- a/cvat-cli/src/cvat_cli/__main__.py +++ b/cvat-cli/src/cvat_cli/__main__.py @@ -33,10 +33,15 @@ def configure_logger(level): def build_client(parsed_args: SimpleNamespace, logger: logging.Logger) -> Client: config = Config(verify_ssl=not parsed_args.insecure) + url = parsed_args.server_host + if parsed_args.server_port: + url += f":{parsed_args.server_port}" + return Client( - url="{host}:{port}".format(host=parsed_args.server_host, port=parsed_args.server_port), + url=url, logger=logger, config=config, + check_server_version=False, # version is checked after auth to support versions < 2.3 ) diff --git a/cvat-cli/src/cvat_cli/cli.py b/cvat-cli/src/cvat_cli/cli.py index 4609f95491ab..00122341cf58 100644 --- a/cvat-cli/src/cvat_cli/cli.py +++ b/cvat-cli/src/cvat_cli/cli.py @@ -23,6 +23,8 @@ def __init__(self, client: Client, credentials: Tuple[str, str]): self.client.login(credentials) + self.client.check_server_version(fail_if_unsupported=False) + def tasks_list(self, *, use_json_output: bool = False, **kwargs): """List all tasks in either basic or JSON format.""" results = self.client.tasks.list(return_json=use_json_output, **kwargs) diff --git a/cvat-sdk/.gitignore b/cvat-sdk/.gitignore index f27f78919cb0..d01a61d14490 100644 --- a/cvat-sdk/.gitignore +++ b/cvat-sdk/.gitignore @@ -70,10 +70,10 @@ schema/ .openapi-generator/ # Generated code -cvat_sdk/api_client/ -cvat_sdk/version.py -requirements/ -docs/ -setup.py -README.md -MANIFEST.in \ No newline at end of file +/cvat_sdk/api_client/ +/cvat_sdk/version.py +/requirements/ +/docs/ +/setup.py +/README.md +/MANIFEST.in \ No newline at end of file diff --git a/cvat-sdk/cvat_sdk/core/client.py b/cvat-sdk/cvat_sdk/core/client.py index 99033f37e9e8..a60ec63767af 100644 --- a/cvat-sdk/cvat_sdk/core/client.py +++ b/cvat-sdk/cvat_sdk/core/client.py @@ -12,11 +12,12 @@ from typing import Any, Dict, Optional, Sequence, Tuple import attrs +import packaging.version as pv import urllib3 import urllib3.exceptions -from cvat_sdk.api_client import ApiClient, Configuration, models -from cvat_sdk.core.exceptions import InvalidHostException +from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models +from cvat_sdk.core.exceptions import IncompatibleVersionException, InvalidHostException from cvat_sdk.core.helpers import expect_status from cvat_sdk.core.proxies.issues import CommentsRepo, IssuesRepo from cvat_sdk.core.proxies.jobs import JobsRepo @@ -24,17 +25,19 @@ from cvat_sdk.core.proxies.projects import ProjectsRepo from cvat_sdk.core.proxies.tasks import TasksRepo from cvat_sdk.core.proxies.users import UsersRepo +from cvat_sdk.version import VERSION @attrs.define class Config: status_check_period: float = 5 - """In seconds""" + """Operation status check period, in seconds""" + + allow_unsupported_server: bool = True + """Allow to use SDK with an unsupported server version. If disabled, raise an exception""" verify_ssl: Optional[bool] = None - """ - Whether to verify host SSL certificate or not. - """ + """Whether to verify host SSL certificate or not""" class Client: @@ -42,9 +45,21 @@ class Client: Manages session and configuration. """ + SUPPORTED_SERVER_VERSIONS = ( + pv.Version("2.0"), + pv.Version("2.1"), + pv.Version("2.2"), + pv.Version("2.3"), + ) + def __init__( - self, url: str, *, logger: Optional[logging.Logger] = None, config: Optional[Config] = None - ): + self, + url: str, + *, + logger: Optional[logging.Logger] = None, + config: Optional[Config] = None, + check_server_version: bool = True, + ) -> None: url = self._validate_and_prepare_url(url) self.logger = logger or logging.getLogger(__name__) self.config = config or Config() @@ -53,6 +68,9 @@ def __init__( Configuration(host=self.api_map.host, verify_ssl=self.config.verify_ssl) ) + if check_server_version: + self.check_server_version() + self._repos: Dict[str, Repo] = {} ALLOWED_SCHEMAS = ("https", "http") @@ -87,12 +105,14 @@ def _detect_schema(cls, base_url: str) -> str: _request_timeout=5, _parse_response=False, _check_status=False ) - if response.status == 401: + if response.status in [200, 401]: + # Server versions prior to 2.3.0 respond with unauthorized + # 2.3.0 allows unauthorized access return schema raise InvalidHostException( "Failed to detect host schema automatically, please check " - "the server url and try to specify schema explicitly" + "the server url and try to specify 'https://' or 'http://' explicitly" ) def __enter__(self): @@ -162,6 +182,44 @@ def wait_for_completion( return response + def check_server_version(self, fail_if_unsupported: Optional[bool] = None) -> None: + if fail_if_unsupported is None: + fail_if_unsupported = not self.config.allow_unsupported_server + + try: + server_version = self.get_server_version() + except exceptions.ApiException as e: + msg = ( + "Failed to retrieve server API version: %s. " + "Some SDK functions may not work properly with this server." + ) % (e,) + self.logger.warning(msg) + if fail_if_unsupported: + raise IncompatibleVersionException(msg) + return + + sdk_version = pv.Version(VERSION) + + # We only check base version match. Micro releases and fixes do not affect + # API compatibility in general. + if all( + server_version.base_version != sv.base_version for sv in self.SUPPORTED_SERVER_VERSIONS + ): + msg = ( + "Server version '%s' is not compatible with SDK version '%s'. " + "Some SDK functions may not work properly with this server. " + "You can continue using this SDK, or you can " + "try to update with 'pip install cvat-sdk'." + ) % (server_version, sdk_version) + self.logger.warning(msg) + if fail_if_unsupported: + raise IncompatibleVersionException(msg) + + def get_server_version(self) -> pv.Version: + # TODO: allow to use this endpoint unauthorized + (about, _) = self.api_client.server_api.retrieve_about() + return pv.Version(about.version) + def _get_repo(self, key: str) -> Repo: _repo_map = { "tasks": TasksRepo, diff --git a/cvat-sdk/cvat_sdk/core/exceptions.py b/cvat-sdk/cvat_sdk/core/exceptions.py index c458bf02d102..b90a8fc18f54 100644 --- a/cvat-sdk/cvat_sdk/core/exceptions.py +++ b/cvat-sdk/cvat_sdk/core/exceptions.py @@ -9,3 +9,7 @@ class CvatSdkException(Exception): class InvalidHostException(CvatSdkException): """Indicates an invalid hostname error""" + + +class IncompatibleVersionException(CvatSdkException): + """Indicates server and SDK version mismatch""" diff --git a/cvat-sdk/gen/templates/requirements/base.txt b/cvat-sdk/gen/templates/requirements/base.txt index f22bae3b6fe0..ffc88d7e7eff 100644 --- a/cvat-sdk/gen/templates/requirements/base.txt +++ b/cvat-sdk/gen/templates/requirements/base.txt @@ -1,7 +1,8 @@ -r api_client.txt attrs >= 21.4.0 +packaging >= 21.3 Pillow >= 9.0.1 tqdm >= 4.64.0 tuspy == 0.2.5 # have it pinned, because SDK has lots of patched TUS code -typing_extensions >= 4.2.0 +typing_extensions >= 4.2.0 \ No newline at end of file diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py index 43b82ab405c8..a503c811ea46 100644 --- a/cvat/apps/engine/tests/test_rest_api.py +++ b/cvat/apps/engine/tests/test_rest_api.py @@ -504,7 +504,7 @@ def test_api_v2_server_about_user(self): def test_api_v2_server_about_no_auth(self): response = self._run_api_v2_server_about(None) - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_api_server_about_versions_admin(self): for version in settings.REST_FRAMEWORK['ALLOWED_VERSIONS']: diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 37c2a28e1c92..c00e4f139fbc 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -98,7 +98,9 @@ def get_serializer(self, *args, **kwargs): responses={ '200': AboutSerializer, }) - @action(detail=False, methods=['GET'], serializer_class=AboutSerializer) + @action(detail=False, methods=['GET'], serializer_class=AboutSerializer, + permission_classes=[] # This endpoint is available for everyone + ) def about(request): from cvat import __version__ as cvat_version about = { diff --git a/cvat/settings/base.py b/cvat/settings/base.py index ea252b11ef3e..4a18a9e41b0f 100644 --- a/cvat/settings/base.py +++ b/cvat/settings/base.py @@ -23,6 +23,7 @@ import mimetypes from corsheaders.defaults import default_headers from distutils.util import strtobool +from cvat import __version__ mimetypes.add_type("application/wasm", ".wasm", True) @@ -517,7 +518,7 @@ def add_ssh_keys(): # Statically set schema version. May also be an empty string. When used together with # view versioning, will become '0.0.0 (v2)' for 'v2' versioned requests. # Set VERSION to None if only the request version should be rendered. - 'VERSION': '2.1.0', + 'VERSION': __version__, 'CONTACT': { 'name': 'CVAT.ai team', 'url': 'https://github.com/cvat-ai/cvat', diff --git a/tests/python/cli/test_cli.py b/tests/python/cli/test_cli.py index f437f9ba815d..cf1f6a69b095 100644 --- a/tests/python/cli/test_cli.py +++ b/tests/python/cli/test_cli.py @@ -7,9 +7,10 @@ import os from pathlib import Path +import packaging.version as pv import pytest from cvat_cli.cli import CLI -from cvat_sdk import make_client +from cvat_sdk import Client, make_client from cvat_sdk.api_client import exceptions from cvat_sdk.core.proxies.tasks import ResourceType, Task from PIL import Image @@ -190,6 +191,17 @@ def test_can_create_from_backup(self, fxt_new_task: Task, fxt_backup_file: Path) assert task_id != fxt_new_task.id assert self.client.tasks.retrieve(task_id).size == fxt_new_task.size + def test_can_warn_on_mismatching_server_version(self, monkeypatch, caplog): + def mocked_version(_): + return pv.Version("0") + + # We don't actually run a separate process in the tests here, so it works + monkeypatch.setattr(Client, "get_server_version", mocked_version) + + self.run_cli("ls") + + assert "Server version '0' is not compatible with SDK version" in caplog.text + @pytest.mark.parametrize("verify", [True, False]) def test_can_control_ssl_verification_with_arg(self, monkeypatch, verify: bool): # TODO: Very hacky implementation, improve it, if possible diff --git a/tests/python/rest_api/test_server.py b/tests/python/rest_api/test_server.py new file mode 100644 index 000000000000..d677f21d91f3 --- /dev/null +++ b/tests/python/rest_api/test_server.py @@ -0,0 +1,36 @@ +# Copyright (C) 2022 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + + +from http import HTTPStatus +import pytest +from shared.utils.config import make_api_client + + +@pytest.mark.usefixtures('dontchangedb') +class TestGetServer: + def test_can_retrieve_about_unauthorized(self): + with make_api_client(user=None, password=None) as api_client: + (data, response) = api_client.server_api.retrieve_about() + + assert response.status == HTTPStatus.OK + assert data.version + + def test_can_retrieve_formats(self, admin_user: str): + with make_api_client(admin_user) as api_client: + (data, response) = api_client.server_api.retrieve_annotation_formats() + + assert response.status == HTTPStatus.OK + assert len(data.importers) != 0 + assert len(data.exporters) != 0 + + +@pytest.mark.usefixtures('dontchangedb') +class TestGetSchema: + def test_can_get_schema_unauthorized(self): + with make_api_client(user=None, password=None) as api_client: + (data, response) = api_client.schema_api.retrieve() + + assert response.status == HTTPStatus.OK + assert data diff --git a/tests/python/sdk/test_client.py b/tests/python/sdk/test_client.py index e240e3d7625d..7fdf70eee650 100644 --- a/tests/python/sdk/test_client.py +++ b/tests/python/sdk/test_client.py @@ -3,13 +3,15 @@ # SPDX-License-Identifier: MIT import io +from contextlib import ExitStack from logging import Logger from typing import Tuple +import packaging.version as pv import pytest from cvat_sdk import Client from cvat_sdk.core.client import Config, make_client -from cvat_sdk.core.exceptions import InvalidHostException +from cvat_sdk.core.exceptions import IncompatibleVersionException, InvalidHostException from cvat_sdk.exceptions import ApiException from shared.utils.config import BASE_URL, USER_PASS @@ -48,6 +50,13 @@ def test_can_logout(self): assert not self.client.has_credentials() + def test_can_get_server_version(self): + self.client.login((self.user, USER_PASS)) + + version = self.client.get_server_version() + + assert (version.major, version.minor) >= (2, 0) + def test_can_detect_server_schema_if_not_provided(): host, port = BASE_URL.split("://", maxsplit=1)[1].rsplit(":", maxsplit=1) @@ -71,6 +80,72 @@ def test_can_reject_invalid_server_schema(): assert capture.match(r"Invalid url schema 'ftp'") +@pytest.mark.parametrize("raise_exception", (True, False)) +def test_can_warn_on_mismatching_server_version( + fxt_logger: Tuple[Logger, io.StringIO], monkeypatch, raise_exception: bool +): + logger, logger_stream = fxt_logger + + def mocked_version(_): + return pv.Version("0") + + monkeypatch.setattr(Client, "get_server_version", mocked_version) + + config = Config() + + with ExitStack() as es: + if raise_exception: + config.allow_unsupported_server = False + es.enter_context(pytest.raises(IncompatibleVersionException)) + + Client(url=BASE_URL, logger=logger, config=config) + + assert "Server version '0' is not compatible with SDK version" in logger_stream.getvalue() + + +@pytest.mark.parametrize("do_check", (True, False)) +def test_can_check_server_version_in_ctor( + fxt_logger: Tuple[Logger, io.StringIO], monkeypatch, do_check: bool +): + logger, logger_stream = fxt_logger + + def mocked_version(_): + return pv.Version("0") + + monkeypatch.setattr(Client, "get_server_version", mocked_version) + + config = Config() + config.allow_unsupported_server = False + + with ExitStack() as es: + if do_check: + es.enter_context(pytest.raises(IncompatibleVersionException)) + + Client(url=BASE_URL, logger=logger, config=config, check_server_version=do_check) + + assert ( + "Server version '0' is not compatible with SDK version" in logger_stream.getvalue() + ) == do_check + + +def test_can_check_server_version_in_method(fxt_logger: Tuple[Logger, io.StringIO], monkeypatch): + logger, logger_stream = fxt_logger + + def mocked_version(_): + return pv.Version("0") + + monkeypatch.setattr(Client, "get_server_version", mocked_version) + + config = Config() + config.allow_unsupported_server = False + client = Client(url=BASE_URL, logger=logger, config=config, check_server_version=False) + + with client, pytest.raises(IncompatibleVersionException): + client.check_server_version() + + assert "Server version '0' is not compatible with SDK version" in logger_stream.getvalue() + + @pytest.mark.parametrize("verify", [True, False]) def test_can_control_ssl_verification_with_config(verify: bool): config = Config(verify_ssl=verify)