Fix typing && use ThreadPoolExecutor instead of ThreadPool && fix failed tests
Marishka17 committed May 23, 2024
1 parent 914b7bc commit 7b21a8c
Showing 3 changed files with 49 additions and 26 deletions.
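
For context on the headline change, the download helpers below switch from multiprocessing.pool.ThreadPool to concurrent.futures.ThreadPoolExecutor. A minimal, standalone sketch of the adopted pattern, assuming an illustrative download function and file names (not CVAT code):

    from concurrent.futures import ThreadPoolExecutor, as_completed
    import os

    def download_file(key: str, path: str) -> None:
        # Stand-in for the real per-file download.
        print(f"downloading {key} -> {path}")

    files = ["a.jpg", "b.jpg", "c.jpg"]
    upload_dir = "/tmp/downloads"

    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = [
            executor.submit(download_file, f, os.path.join(upload_dir, f))
            for f in files
        ]
        for future in as_completed(futures):
            # result() re-raises any exception from the worker thread,
            # so a failed download surfaces instead of being ignored.
            future.result()
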
47 changes: 28 additions & 19 deletions cvat/apps/engine/cloud_provider.py
@@ -10,9 +10,8 @@
from abc import ABC, abstractmethod, abstractproperty
from enum import Enum
from io import BytesIO
from multiprocessing.pool import ThreadPool
from typing import Dict, List, Optional, Any, Callable, TypeVar, Iterator
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed

import boto3
from azure.core.exceptions import HttpResponseError, ResourceExistsError
@@ -33,6 +32,15 @@
from cvat.apps.engine.models import CloudProviderChoice, CredentialsTypeChoice
from cvat.apps.engine.utils import get_cpu_number

class NamedBytesIO(BytesIO):
@property
def filename(self) -> Optional[str]:
return getattr(self, '_filename', None)

@filename.setter
def filename(self, value: str) -> None:
self._filename = value

slogger = ServerLogManager(__name__)

ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -138,7 +146,7 @@ def get_file_last_modified(self, key):
pass

@abstractmethod
def download_fileobj(self, key):
def download_fileobj(self, key: str) -> NamedBytesIO:
pass

def download_file(self, key, path):
@@ -173,7 +181,7 @@ def download_range_of_bytes(self, key: str, stop_byte: int, start_byte: int = 0)
def _download_range_of_bytes(self, key: str, stop_byte: int, start_byte: int):
pass

def optimally_image_download(self, key: str, chunk_size: int = 65536) -> BytesIO:
def optimally_image_download(self, key: str, chunk_size: int = 65536) -> NamedBytesIO:
"""
Downloads an image using the following approach:
First, we try to download the first N bytes of the image, which should be enough to determine the image properties.
@@ -192,15 +200,16 @@ def optimally_image_download(self, key: str, chunk_size: int = 65536) -> BytesIO
image_parser.feed(chunk)

if image_parser.image:
buff = BytesIO(chunk)
buff = NamedBytesIO(chunk)
buff.filename = key
else:
buff = self.download_fileobj(key)
image_size_in_bytes = len(buff.getvalue())
slogger.glob.warning(
f'The {chunk_size} bytes were not enough to parse "{key}" image. '
f'Image size was {image_size_in_bytes} bytes. Image resolution was {Image.open(buff).size}. '
f'Downloaded percent was {round(min(chunk_size, image_size_in_bytes) / image_size_in_bytes * 100)}')
buff.filename = key

return buff
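
A rough standalone illustration of the header-only probing described in the docstring above, using Pillow's incremental parser; the file path and chunk size are placeholders, and the fallback to a full download is only hinted at:

    from PIL import ImageFile

    def probe_image_header(path: str, chunk_size: int = 65536):
        # Feed only the first chunk_size bytes to Pillow's incremental parser.
        parser = ImageFile.Parser()
        with open(path, "rb") as f:
            parser.feed(f.read(chunk_size))
        if parser.image is not None:
            # Enough of the header arrived to know the basic properties.
            return parser.image.size, parser.image.mode
        return None  # not enough data; the caller would fall back to a full download

    # print(probe_image_header("example.jpg"))
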

def bulk_download_to_memory(
@@ -225,13 +234,10 @@ def bulk_download_to_dir(
) -> None:
threads_number = normalize_threads_number(threads_number, len(files))

args = zip(files, [os.path.join(upload_dir, f) for f in files])
if threads_number > 1:
with ThreadPool(threads_number) as pool:
pool.map(lambda x: self.download_file(*x), args)
else:
for f, path in args:
self.download_file(f, path)
with ThreadPoolExecutor(max_workers=threads_number) as executor:
futures = [executor.submit(self.download_file, f, os.path.join(upload_dir, f)) for f in files]
for future in as_completed(futures):
future.result()

@abstractmethod
def upload_fileobj(self, file_obj, file_name):
@@ -524,14 +530,15 @@ def _list_raw_content_on_one_page(

@validate_file_status
@validate_bucket_status
def download_fileobj(self, key):
buf = BytesIO()
def download_fileobj(self, key: str) -> NamedBytesIO:
buf = NamedBytesIO()
self.bucket.download_fileobj(
Key=key,
Fileobj=buf,
Config=TransferConfig(max_io_queue=self.transfer_config['max_io_queue'])
)
buf.seek(0)
buf.filename = key
return buf

@validate_file_status
@@ -725,15 +732,16 @@ def _list_raw_content_on_one_page(

@validate_file_status
@validate_bucket_status
def download_fileobj(self, key):
buf = BytesIO()
def download_fileobj(self, key: str) -> NamedBytesIO:
buf = NamedBytesIO()
storage_stream_downloader = self._client.download_blob(
blob=key,
offset=None,
length=None,
)
storage_stream_downloader.download_to_stream(buf, max_concurrency=self.MAX_CONCURRENCY)
buf.seek(0)
buf.filename = key
return buf

@validate_file_status
@@ -838,11 +846,12 @@ def _list_raw_content_on_one_page(

@validate_file_status
@validate_bucket_status
def download_fileobj(self, key):
buf = BytesIO()
def download_fileobj(self, key: str) -> NamedBytesIO:
buf = NamedBytesIO()
blob = self.bucket.blob(key)
self._client.download_blob_to_file(blob, buf)
buf.seek(0)
buf.filename = key
return buf

@validate_file_status
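
All three providers now return the NamedBytesIO introduced above from download_fileobj, so the object key travels with the buffer itself. A small self-contained sketch of what that enables; the payload and key below are made up:

    from io import BytesIO
    from typing import Optional

    # Same shape as the NamedBytesIO added above: a BytesIO that remembers its key.
    class NamedBytesIO(BytesIO):
        @property
        def filename(self) -> Optional[str]:
            return getattr(self, '_filename', None)

        @filename.setter
        def filename(self, value: str) -> None:
            self._filename = value

    buf = NamedBytesIO(b"\x89PNG...")   # payload is illustrative
    buf.filename = "images/0001.png"

    # Downstream code (e.g. manifest building) can sort or label buffers
    # without carrying the key separately.
    print(buf.filename, len(buf.getvalue()))
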
16 changes: 9 additions & 7 deletions utils/dataset_manifest/core.py
@@ -18,8 +18,10 @@

from .errors import InvalidManifestError, InvalidVideoError
from .utils import SortingMethod, md5_hash, rotate_image, sort
from .types import NamedBytesIO

from typing import Any, Dict, List, Union, Optional, Iterator, Tuple, Callable

from typing import Any, Dict, List, Union, Optional, Iterator, Tuple

class VideoStreamReader:
def __init__(self, source_path, chunk_size, force):
@@ -141,7 +143,7 @@ def __iter__(self) -> Iterator[Union[int, Tuple[int, int, str]]]:

class DatasetImagesReader:
def __init__(self,
sources: Union[List[str], Iterator[BytesIO]],
sources: Union[List[str], Iterator[NamedBytesIO]],
*,
start: int = 0,
step: int = 1,
@@ -155,7 +157,7 @@ def __init__(self,

if not self._is_generator_used:
raw_data_used = not isinstance(sources[0], str)
func = (lambda x: x.filename) if raw_data_used else None
func: Optional[Callable[[NamedBytesIO], str]] = (lambda x: x.filename) if raw_data_used else None
self._sources = sort(sources, sorting_method, func=func)
else:
if sorting_method != SortingMethod.PREDEFINED:
@@ -194,7 +196,7 @@ def step(self):
def step(self, value):
self._step = int(value)

def _get_img_properties(self, image: Union[str, BytesIO]) -> Dict[str, Any]:
def _get_img_properties(self, image: Union[str, NamedBytesIO]) -> Dict[str, Any]:
img = Image.open(image, mode='r')
if self._data_dir:
img_name = os.path.relpath(image, self._data_dir)
@@ -234,7 +236,7 @@ def __iter__(self):

@property
def range_(self):
return range(self._start, self._stop, self._step)
return range(self._start, self._stop + 1, self._step)

def __len__(self):
return len(self.range_)
@@ -245,7 +247,7 @@ def __init__(self, **kwargs):

def __iter__(self):
sources = (i for i in self._sources)
for idx in range(self._stop):
for idx in range(self._stop + 1):
if idx in self.range_:
image = next(sources)
img_name = os.path.relpath(image, self._data_dir) if self._data_dir \
@@ -353,7 +355,7 @@ def partial_update(self, manifest, number):

def __getitem__(self, number):
if not 0 <= number < len(self):
raise IndexError('Invalid index number: {}\nMax: {}'.format(number, len(self) - 1))
raise IndexError('Invalid index number: {}, Maximum allowed index is {}'.format(number, len(self) - 1))

return self._index[number]

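
The two "+ 1" adjustments in DatasetImagesReader above make self._stop an inclusive upper bound when iterating indices. A tiny illustration with made-up values:

    start, stop, step = 0, 9, 3

    # Before: the stop index was excluded, so the last item (9) was skipped.
    print(list(range(start, stop, step)))      # [0, 3, 6]

    # After: stop is treated as inclusive.
    print(list(range(start, stop + 1, step)))  # [0, 3, 6, 9]
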
12 changes: 12 additions & 0 deletions utils/dataset_manifest/types.py
@@ -0,0 +1,12 @@
# Copyright (C) 2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from io import BytesIO
from typing import Protocol

class Named(Protocol):
filename: str

class NamedBytesIO(BytesIO, Named):
pass
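
The Named protocol lets annotations accept anything that exposes a filename attribute, with NamedBytesIO as the concrete in-memory implementation used by the manifest reader. A brief usage sketch; the describe helper is illustrative, not part of the commit:

    from io import BytesIO
    from typing import Protocol

    class Named(Protocol):
        filename: str

    class NamedBytesIO(BytesIO, Named):
        pass

    def describe(source: Named) -> str:
        # Anything with a filename attribute satisfies the protocol.
        return f"source: {source.filename}"

    buf = NamedBytesIO(b"payload")
    buf.filename = "images/0001.png"   # plain attribute, unlike the property in cloud_provider.py
    print(describe(buf))
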
