Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check server version in SDK #4935

Merged
merged 11 commits into from
Sep 30, 2022
7 changes: 6 additions & 1 deletion cvat-cli/src/cvat_cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,15 @@ def configure_logger(level):
def build_client(parsed_args: SimpleNamespace, logger: logging.Logger) -> Client:
config = Config(verify_ssl=not parsed_args.insecure)

url = parsed_args.server_host
if parsed_args.server_port:
url += f":{parsed_args.server_port}"

return Client(
url="{host}:{port}".format(host=parsed_args.server_host, port=parsed_args.server_port),
url=url,
logger=logger,
config=config,
check_server_version=False, # version is checked after auth to support versions < 2.3
)


Expand Down
2 changes: 2 additions & 0 deletions cvat-cli/src/cvat_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def __init__(self, client: Client, credentials: Tuple[str, str]):

self.client.login(credentials)

self.client.check_server_version(fail_if_unsupported=False)

def tasks_list(self, *, use_json_output: bool = False, **kwargs):
"""List all tasks in either basic or JSON format."""
results = self.client.tasks.list(return_json=use_json_output, **kwargs)
Expand Down
14 changes: 7 additions & 7 deletions cvat-sdk/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ schema/
.openapi-generator/

# Generated code
cvat_sdk/api_client/
cvat_sdk/version.py
requirements/
docs/
setup.py
README.md
MANIFEST.in
/cvat_sdk/api_client/
/cvat_sdk/version.py
/requirements/
/docs/
/setup.py
/README.md
/MANIFEST.in
78 changes: 68 additions & 10 deletions cvat-sdk/cvat_sdk/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,54 @@
from typing import Any, Dict, Optional, Sequence, Tuple

import attrs
import packaging.version as pv
import urllib3
import urllib3.exceptions

from cvat_sdk.api_client import ApiClient, Configuration, models
from cvat_sdk.core.exceptions import InvalidHostException
from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models
from cvat_sdk.core.exceptions import IncompatibleVersionException, InvalidHostException
from cvat_sdk.core.helpers import expect_status
from cvat_sdk.core.proxies.issues import CommentsRepo, IssuesRepo
from cvat_sdk.core.proxies.jobs import JobsRepo
from cvat_sdk.core.proxies.model_proxy import Repo
from cvat_sdk.core.proxies.projects import ProjectsRepo
from cvat_sdk.core.proxies.tasks import TasksRepo
from cvat_sdk.core.proxies.users import UsersRepo
from cvat_sdk.version import VERSION


@attrs.define
class Config:
status_check_period: float = 5
"""In seconds"""
"""Operation status check period, in seconds"""

allow_unsupported_server: bool = True
"""Allow to use SDK with an unsupported server version. If disabled, raise an exception"""

verify_ssl: Optional[bool] = None
"""
Whether to verify host SSL certificate or not.
"""
"""Whether to verify host SSL certificate or not"""


class Client:
"""
Manages session and configuration.
"""

SUPPORTED_SERVER_VERSIONS = (
pv.Version("2.0"),
pv.Version("2.1"),
pv.Version("2.2"),
pv.Version("2.3"),
)

def __init__(
self, url: str, *, logger: Optional[logging.Logger] = None, config: Optional[Config] = None
):
self,
url: str,
*,
logger: Optional[logging.Logger] = None,
config: Optional[Config] = None,
check_server_version: bool = True,
) -> None:
url = self._validate_and_prepare_url(url)
self.logger = logger or logging.getLogger(__name__)
self.config = config or Config()
Expand All @@ -53,6 +68,9 @@ def __init__(
Configuration(host=self.api_map.host, verify_ssl=self.config.verify_ssl)
)

if check_server_version:
self.check_server_version()

self._repos: Dict[str, Repo] = {}

ALLOWED_SCHEMAS = ("https", "http")
Expand Down Expand Up @@ -87,12 +105,14 @@ def _detect_schema(cls, base_url: str) -> str:
_request_timeout=5, _parse_response=False, _check_status=False
)

if response.status == 401:
if response.status in [200, 401]:
# Server versions prior to 2.3.0 respond with unauthorized
# 2.3.0 allows unauthorized access
return schema

raise InvalidHostException(
"Failed to detect host schema automatically, please check "
"the server url and try to specify schema explicitly"
"the server url and try to specify 'https://' or 'http://' explicitly"
)

def __enter__(self):
Expand Down Expand Up @@ -162,6 +182,44 @@ def wait_for_completion(

return response

def check_server_version(self, fail_if_unsupported: Optional[bool] = None) -> None:
if fail_if_unsupported is None:
fail_if_unsupported = not self.config.allow_unsupported_server

try:
server_version = self.get_server_version()
except exceptions.ApiException as e:
msg = (
"Failed to retrieve server API version: %s. "
"Some SDK functions may not work properly with this server."
) % (e,)
self.logger.warning(msg)
if fail_if_unsupported:
raise IncompatibleVersionException(msg)
return

sdk_version = pv.Version(VERSION)

# We only check base version match. Micro releases and fixes do not affect
# API compatibility in general.
if all(
server_version.base_version != sv.base_version for sv in self.SUPPORTED_SERVER_VERSIONS
):
msg = (
"Server version '%s' is not compatible with SDK version '%s'. "
"Some SDK functions may not work properly with this server. "
"You can continue using this SDK, or you can "
"try to update with 'pip install cvat-sdk'."
) % (server_version, sdk_version)
self.logger.warning(msg)
if fail_if_unsupported:
raise IncompatibleVersionException(msg)

def get_server_version(self) -> pv.Version:
# TODO: allow to use this endpoint unauthorized
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably need to remove the TODO

(about, _) = self.api_client.server_api.retrieve_about()
return pv.Version(about.version)

def _get_repo(self, key: str) -> Repo:
_repo_map = {
"tasks": TasksRepo,
Expand Down
4 changes: 4 additions & 0 deletions cvat-sdk/cvat_sdk/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ class CvatSdkException(Exception):

class InvalidHostException(CvatSdkException):
"""Indicates an invalid hostname error"""


class IncompatibleVersionException(CvatSdkException):
"""Indicates server and SDK version mismatch"""
3 changes: 2 additions & 1 deletion cvat-sdk/gen/templates/requirements/base.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
-r api_client.txt

attrs >= 21.4.0
packaging >= 21.3
Pillow >= 9.0.1
tqdm >= 4.64.0
tuspy == 0.2.5 # have it pinned, because SDK has lots of patched TUS code
typing_extensions >= 4.2.0
typing_extensions >= 4.2.0
2 changes: 1 addition & 1 deletion cvat/apps/engine/tests/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def test_api_v2_server_about_user(self):

def test_api_v2_server_about_no_auth(self):
response = self._run_api_v2_server_about(None)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.status_code, status.HTTP_200_OK)

def test_api_server_about_versions_admin(self):
for version in settings.REST_FRAMEWORK['ALLOWED_VERSIONS']:
Expand Down
4 changes: 3 additions & 1 deletion cvat/apps/engine/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def get_serializer(self, *args, **kwargs):
responses={
'200': AboutSerializer,
})
@action(detail=False, methods=['GET'], serializer_class=AboutSerializer)
@action(detail=False, methods=['GET'], serializer_class=AboutSerializer,
permission_classes=[] # This endpoint is available for everyone
)
def about(request):
from cvat import __version__ as cvat_version
about = {
Expand Down
3 changes: 2 additions & 1 deletion cvat/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import mimetypes
from corsheaders.defaults import default_headers
from distutils.util import strtobool
from cvat import __version__

mimetypes.add_type("application/wasm", ".wasm", True)

Expand Down Expand Up @@ -517,7 +518,7 @@ def add_ssh_keys():
# Statically set schema version. May also be an empty string. When used together with
# view versioning, will become '0.0.0 (v2)' for 'v2' versioned requests.
# Set VERSION to None if only the request version should be rendered.
'VERSION': '2.1.0',
'VERSION': __version__,
'CONTACT': {
'name': 'CVAT.ai team',
'url': 'https://github.com/cvat-ai/cvat',
Expand Down
14 changes: 13 additions & 1 deletion tests/python/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import os
from pathlib import Path

import packaging.version as pv
import pytest
from cvat_cli.cli import CLI
from cvat_sdk import make_client
from cvat_sdk import Client, make_client
from cvat_sdk.api_client import exceptions
from cvat_sdk.core.proxies.tasks import ResourceType, Task
from PIL import Image
Expand Down Expand Up @@ -190,6 +191,17 @@ def test_can_create_from_backup(self, fxt_new_task: Task, fxt_backup_file: Path)
assert task_id != fxt_new_task.id
assert self.client.tasks.retrieve(task_id).size == fxt_new_task.size

def test_can_warn_on_mismatching_server_version(self, monkeypatch, caplog):
def mocked_version(_):
return pv.Version("0")

# We don't actually run a separate process in the tests here, so it works
monkeypatch.setattr(Client, "get_server_version", mocked_version)

self.run_cli("ls")

assert "Server version '0' is not compatible with SDK version" in caplog.text

@pytest.mark.parametrize("verify", [True, False])
def test_can_control_ssl_verification_with_arg(self, monkeypatch, verify: bool):
# TODO: Very hacky implementation, improve it, if possible
Expand Down
36 changes: 36 additions & 0 deletions tests/python/rest_api/test_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT


from http import HTTPStatus
import pytest
from shared.utils.config import make_api_client


@pytest.mark.usefixtures('dontchangedb')
class TestGetServer:
def test_can_retrieve_about_unauthorized(self):
with make_api_client(user=None, password=None) as api_client:
(data, response) = api_client.server_api.retrieve_about()

assert response.status == HTTPStatus.OK
assert data.version

def test_can_retrieve_formats(self, admin_user: str):
with make_api_client(admin_user) as api_client:
(data, response) = api_client.server_api.retrieve_annotation_formats()

assert response.status == HTTPStatus.OK
assert len(data.importers) != 0
assert len(data.exporters) != 0


@pytest.mark.usefixtures('dontchangedb')
class TestGetSchema:
def test_can_get_schema_unauthorized(self):
with make_api_client(user=None, password=None) as api_client:
(data, response) = api_client.schema_api.retrieve()

assert response.status == HTTPStatus.OK
assert data
Loading