Skip to content

Commit

Permalink
Check server version in SDK (#4935)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiltsov-max authored Sep 30, 2022
1 parent 55913f0 commit 2e15025
Show file tree
Hide file tree
Showing 12 changed files with 220 additions and 24 deletions.
7 changes: 6 additions & 1 deletion cvat-cli/src/cvat_cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,15 @@ def configure_logger(level):
def build_client(parsed_args: SimpleNamespace, logger: logging.Logger) -> Client:
config = Config(verify_ssl=not parsed_args.insecure)

url = parsed_args.server_host
if parsed_args.server_port:
url += f":{parsed_args.server_port}"

return Client(
url="{host}:{port}".format(host=parsed_args.server_host, port=parsed_args.server_port),
url=url,
logger=logger,
config=config,
check_server_version=False, # version is checked after auth to support versions < 2.3
)


Expand Down
2 changes: 2 additions & 0 deletions cvat-cli/src/cvat_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def __init__(self, client: Client, credentials: Tuple[str, str]):

self.client.login(credentials)

self.client.check_server_version(fail_if_unsupported=False)

def tasks_list(self, *, use_json_output: bool = False, **kwargs):
"""List all tasks in either basic or JSON format."""
results = self.client.tasks.list(return_json=use_json_output, **kwargs)
Expand Down
14 changes: 7 additions & 7 deletions cvat-sdk/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ schema/
.openapi-generator/

# Generated code
cvat_sdk/api_client/
cvat_sdk/version.py
requirements/
docs/
setup.py
README.md
MANIFEST.in
/cvat_sdk/api_client/
/cvat_sdk/version.py
/requirements/
/docs/
/setup.py
/README.md
/MANIFEST.in
78 changes: 68 additions & 10 deletions cvat-sdk/cvat_sdk/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,54 @@
from typing import Any, Dict, Optional, Sequence, Tuple

import attrs
import packaging.version as pv
import urllib3
import urllib3.exceptions

from cvat_sdk.api_client import ApiClient, Configuration, models
from cvat_sdk.core.exceptions import InvalidHostException
from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models
from cvat_sdk.core.exceptions import IncompatibleVersionException, InvalidHostException
from cvat_sdk.core.helpers import expect_status
from cvat_sdk.core.proxies.issues import CommentsRepo, IssuesRepo
from cvat_sdk.core.proxies.jobs import JobsRepo
from cvat_sdk.core.proxies.model_proxy import Repo
from cvat_sdk.core.proxies.projects import ProjectsRepo
from cvat_sdk.core.proxies.tasks import TasksRepo
from cvat_sdk.core.proxies.users import UsersRepo
from cvat_sdk.version import VERSION


@attrs.define
class Config:
status_check_period: float = 5
"""In seconds"""
"""Operation status check period, in seconds"""

allow_unsupported_server: bool = True
"""Allow to use SDK with an unsupported server version. If disabled, raise an exception"""

verify_ssl: Optional[bool] = None
"""
Whether to verify host SSL certificate or not.
"""
"""Whether to verify host SSL certificate or not"""


class Client:
"""
Manages session and configuration.
"""

SUPPORTED_SERVER_VERSIONS = (
pv.Version("2.0"),
pv.Version("2.1"),
pv.Version("2.2"),
pv.Version("2.3"),
)

def __init__(
self, url: str, *, logger: Optional[logging.Logger] = None, config: Optional[Config] = None
):
self,
url: str,
*,
logger: Optional[logging.Logger] = None,
config: Optional[Config] = None,
check_server_version: bool = True,
) -> None:
url = self._validate_and_prepare_url(url)
self.logger = logger or logging.getLogger(__name__)
self.config = config or Config()
Expand All @@ -53,6 +68,9 @@ def __init__(
Configuration(host=self.api_map.host, verify_ssl=self.config.verify_ssl)
)

if check_server_version:
self.check_server_version()

self._repos: Dict[str, Repo] = {}

ALLOWED_SCHEMAS = ("https", "http")
Expand Down Expand Up @@ -87,12 +105,14 @@ def _detect_schema(cls, base_url: str) -> str:
_request_timeout=5, _parse_response=False, _check_status=False
)

if response.status == 401:
if response.status in [200, 401]:
# Server versions prior to 2.3.0 respond with unauthorized
# 2.3.0 allows unauthorized access
return schema

raise InvalidHostException(
"Failed to detect host schema automatically, please check "
"the server url and try to specify schema explicitly"
"the server url and try to specify 'https://' or 'http://' explicitly"
)

def __enter__(self):
Expand Down Expand Up @@ -162,6 +182,44 @@ def wait_for_completion(

return response

def check_server_version(self, fail_if_unsupported: Optional[bool] = None) -> None:
if fail_if_unsupported is None:
fail_if_unsupported = not self.config.allow_unsupported_server

try:
server_version = self.get_server_version()
except exceptions.ApiException as e:
msg = (
"Failed to retrieve server API version: %s. "
"Some SDK functions may not work properly with this server."
) % (e,)
self.logger.warning(msg)
if fail_if_unsupported:
raise IncompatibleVersionException(msg)
return

sdk_version = pv.Version(VERSION)

# We only check base version match. Micro releases and fixes do not affect
# API compatibility in general.
if all(
server_version.base_version != sv.base_version for sv in self.SUPPORTED_SERVER_VERSIONS
):
msg = (
"Server version '%s' is not compatible with SDK version '%s'. "
"Some SDK functions may not work properly with this server. "
"You can continue using this SDK, or you can "
"try to update with 'pip install cvat-sdk'."
) % (server_version, sdk_version)
self.logger.warning(msg)
if fail_if_unsupported:
raise IncompatibleVersionException(msg)

def get_server_version(self) -> pv.Version:
# TODO: allow to use this endpoint unauthorized
(about, _) = self.api_client.server_api.retrieve_about()
return pv.Version(about.version)

def _get_repo(self, key: str) -> Repo:
_repo_map = {
"tasks": TasksRepo,
Expand Down
4 changes: 4 additions & 0 deletions cvat-sdk/cvat_sdk/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ class CvatSdkException(Exception):

class InvalidHostException(CvatSdkException):
"""Indicates an invalid hostname error"""


class IncompatibleVersionException(CvatSdkException):
"""Indicates server and SDK version mismatch"""
3 changes: 2 additions & 1 deletion cvat-sdk/gen/templates/requirements/base.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
-r api_client.txt

attrs >= 21.4.0
packaging >= 21.3
Pillow >= 9.0.1
tqdm >= 4.64.0
tuspy == 0.2.5 # have it pinned, because SDK has lots of patched TUS code
typing_extensions >= 4.2.0
typing_extensions >= 4.2.0
2 changes: 1 addition & 1 deletion cvat/apps/engine/tests/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def test_api_v2_server_about_user(self):

def test_api_v2_server_about_no_auth(self):
response = self._run_api_v2_server_about(None)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.status_code, status.HTTP_200_OK)

def test_api_server_about_versions_admin(self):
for version in settings.REST_FRAMEWORK['ALLOWED_VERSIONS']:
Expand Down
4 changes: 3 additions & 1 deletion cvat/apps/engine/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def get_serializer(self, *args, **kwargs):
responses={
'200': AboutSerializer,
})
@action(detail=False, methods=['GET'], serializer_class=AboutSerializer)
@action(detail=False, methods=['GET'], serializer_class=AboutSerializer,
permission_classes=[] # This endpoint is available for everyone
)
def about(request):
from cvat import __version__ as cvat_version
about = {
Expand Down
3 changes: 2 additions & 1 deletion cvat/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import mimetypes
from corsheaders.defaults import default_headers
from distutils.util import strtobool
from cvat import __version__

mimetypes.add_type("application/wasm", ".wasm", True)

Expand Down Expand Up @@ -517,7 +518,7 @@ def add_ssh_keys():
# Statically set schema version. May also be an empty string. When used together with
# view versioning, will become '0.0.0 (v2)' for 'v2' versioned requests.
# Set VERSION to None if only the request version should be rendered.
'VERSION': '2.1.0',
'VERSION': __version__,
'CONTACT': {
'name': 'CVAT.ai team',
'url': 'https://github.com/cvat-ai/cvat',
Expand Down
14 changes: 13 additions & 1 deletion tests/python/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import os
from pathlib import Path

import packaging.version as pv
import pytest
from cvat_cli.cli import CLI
from cvat_sdk import make_client
from cvat_sdk import Client, make_client
from cvat_sdk.api_client import exceptions
from cvat_sdk.core.proxies.tasks import ResourceType, Task
from PIL import Image
Expand Down Expand Up @@ -190,6 +191,17 @@ def test_can_create_from_backup(self, fxt_new_task: Task, fxt_backup_file: Path)
assert task_id != fxt_new_task.id
assert self.client.tasks.retrieve(task_id).size == fxt_new_task.size

def test_can_warn_on_mismatching_server_version(self, monkeypatch, caplog):
def mocked_version(_):
return pv.Version("0")

# We don't actually run a separate process in the tests here, so it works
monkeypatch.setattr(Client, "get_server_version", mocked_version)

self.run_cli("ls")

assert "Server version '0' is not compatible with SDK version" in caplog.text

@pytest.mark.parametrize("verify", [True, False])
def test_can_control_ssl_verification_with_arg(self, monkeypatch, verify: bool):
# TODO: Very hacky implementation, improve it, if possible
Expand Down
36 changes: 36 additions & 0 deletions tests/python/rest_api/test_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT


from http import HTTPStatus
import pytest
from shared.utils.config import make_api_client


@pytest.mark.usefixtures('dontchangedb')
class TestGetServer:
def test_can_retrieve_about_unauthorized(self):
with make_api_client(user=None, password=None) as api_client:
(data, response) = api_client.server_api.retrieve_about()

assert response.status == HTTPStatus.OK
assert data.version

def test_can_retrieve_formats(self, admin_user: str):
with make_api_client(admin_user) as api_client:
(data, response) = api_client.server_api.retrieve_annotation_formats()

assert response.status == HTTPStatus.OK
assert len(data.importers) != 0
assert len(data.exporters) != 0


@pytest.mark.usefixtures('dontchangedb')
class TestGetSchema:
def test_can_get_schema_unauthorized(self):
with make_api_client(user=None, password=None) as api_client:
(data, response) = api_client.schema_api.retrieve()

assert response.status == HTTPStatus.OK
assert data
Loading

0 comments on commit 2e15025

Please sign in to comment.