From 48a69cf22e5296dd6ed0074e99521ad486e5696a Mon Sep 17 00:00:00 2001 From: Roman Donchenko Date: Wed, 30 Oct 2024 13:22:33 +0200 Subject: [PATCH] Fix return type annotations for functions used with @contextmanager Some of them are annotated with an `Iterator` return type. However... It just occurred to me that `@contextmanager` cannot work with a function that returns a plain iterator, since it relies on the generator class's `throw` method. `contextmanager` is defined in typeshed as accepting an iterator-returning function, but that appears to be a bug: . Change all such annotations to a `Generator` type instead. Some annotations are also broken in other ways; fix them too. --- cvat-sdk/cvat_sdk/core/client.py | 4 ++-- cvat-sdk/cvat_sdk/core/progress.py | 4 ++-- cvat-sdk/cvat_sdk/core/utils.py | 4 ++-- cvat/apps/engine/media_extractors.py | 4 +++- tests/python/cli/util.py | 4 ++-- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/cvat-sdk/cvat_sdk/core/client.py b/cvat-sdk/cvat_sdk/core/client.py index add7ccb5f3d3..0ae0b88ecad9 100644 --- a/cvat-sdk/cvat_sdk/core/client.py +++ b/cvat-sdk/cvat_sdk/core/client.py @@ -10,7 +10,7 @@ from contextlib import contextmanager, suppress from pathlib import Path from time import sleep -from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, TypeVar +from typing import Any, Dict, Generator, Optional, Sequence, Tuple, TypeVar import attrs import packaging.specifiers as specifiers @@ -121,7 +121,7 @@ def organization_slug(self, org_slug: Optional[str]): self.api_client.default_headers[self._ORG_SLUG_HEADER] = org_slug @contextmanager - def organization_context(self, slug: str) -> Iterator[None]: + def organization_context(self, slug: str) -> Generator[None, None, None]: prev_slug = self.organization_slug self.organization_slug = slug try: diff --git a/cvat-sdk/cvat_sdk/core/progress.py b/cvat-sdk/cvat_sdk/core/progress.py index 7fd2d13a2cd2..fd844de722a0 100644 --- a/cvat-sdk/cvat_sdk/core/progress.py +++ b/cvat-sdk/cvat_sdk/core/progress.py @@ -6,7 +6,7 @@ from __future__ import annotations import contextlib -from typing import ContextManager, Iterable, Optional, TypeVar +from typing import Generator, Iterable, Optional, TypeVar T = TypeVar("T") @@ -26,7 +26,7 @@ class ProgressReporter: """ @contextlib.contextmanager - def task(self, **kwargs) -> ContextManager[None]: + def task(self, **kwargs) -> Generator[None, None, None]: """ Returns a context manager that represents a long-running task for which progress can be reported. diff --git a/cvat-sdk/cvat_sdk/core/utils.py b/cvat-sdk/cvat_sdk/core/utils.py index 0706a2eec613..1ef434e3ad5b 100644 --- a/cvat-sdk/cvat_sdk/core/utils.py +++ b/cvat-sdk/cvat_sdk/core/utils.py @@ -13,7 +13,7 @@ BinaryIO, ContextManager, Dict, - Iterator, + Generator, Literal, Sequence, TextIO, @@ -43,7 +43,7 @@ def atomic_writer( @contextlib.contextmanager def atomic_writer( path: Union[os.PathLike, str], mode: Literal["w", "wb"], encoding: str = "UTF-8" -) -> Iterator[IO]: +) -> Generator[IO, None, None]: """ Returns a context manager that, when entered, returns a handle to a temporary file opened with the specified `mode` and `encoding`. If the context manager diff --git a/cvat/apps/engine/media_extractors.py b/cvat/apps/engine/media_extractors.py index a64637359ff5..c923083b18b3 100644 --- a/cvat/apps/engine/media_extractors.py +++ b/cvat/apps/engine/media_extractors.py @@ -539,7 +539,9 @@ def extract(self): class _AvVideoReading: @contextmanager - def read_av_container(self, source: Union[str, io.BytesIO]) -> av.container.InputContainer: + def read_av_container( + self, source: Union[str, io.BytesIO] + ) -> Generator[av.container.InputContainer, None, None]: if isinstance(source, io.BytesIO): source.seek(0) # required for re-reading diff --git a/tests/python/cli/util.py b/tests/python/cli/util.py index f90f2bb6b73e..ff1173fa4a8d 100644 --- a/tests/python/cli/util.py +++ b/tests/python/cli/util.py @@ -9,7 +9,7 @@ import threading import unittest from pathlib import Path -from typing import Any, Dict, Iterator, List, Union +from typing import Any, Dict, Generator, List, Union import requests @@ -39,7 +39,7 @@ def generate_images(dst_dir: Path, count: int) -> List[Path]: @contextlib.contextmanager -def https_reverse_proxy() -> Iterator[str]: +def https_reverse_proxy() -> Generator[str, None, None]: ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 cert_dir = Path(__file__).parent