From 557e308a73296edf17ae0670bf1f7f8bc9e2c512 Mon Sep 17 00:00:00 2001
From: zhiltsov-max
Date: Wed, 8 Apr 2020 20:47:58 +0300
Subject: [PATCH] Add chunk iterator cache to frame provider (#1367)

* Add chunk iterator cache

* fix
---
 cvat/apps/dataset_manager/bindings.py |  2 +-
 cvat/apps/engine/frame_provider.py    | 51 ++++++++++++++++++++-------
 2 files changed, 40 insertions(+), 13 deletions(-)

diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py
index 7a7b80b27859..53a103f02a68 100644
--- a/cvat/apps/dataset_manager/bindings.py
+++ b/cvat/apps/dataset_manager/bindings.py
@@ -26,7 +26,7 @@ def __iter__(self):
         frames = self._frame_provider.get_frames(
             self._frame_provider.Quality.ORIGINAL,
             self._frame_provider.Type.NUMPY_ARRAY)
-        for item_id, image in enumerate(frames):
+        for item_id, (image, _) in enumerate(frames):
             yield datumaro.DatasetItem(
                 id=item_id,
                 image=Image(image),
diff --git a/cvat/apps/engine/frame_provider.py b/cvat/apps/engine/frame_provider.py
index 7bd60a8a9fbb..25575ea51d36 100644
--- a/cvat/apps/engine/frame_provider.py
+++ b/cvat/apps/engine/frame_provider.py
@@ -2,7 +2,6 @@
 #
 # SPDX-License-Identifier: MIT
 
-import itertools
 import math
 from enum import Enum
 from io import BytesIO
@@ -15,6 +14,33 @@
 from cvat.apps.engine.models import DataChoice
 
 
+class RandomAccessIterator:
+    def __init__(self, iterable):
+        self.iterable = iterable
+        self.iterator = None
+        self.pos = -1
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return self[self.pos + 1]
+
+    def __getitem__(self, idx):
+        assert 0 <= idx
+        if self.iterator is None or idx <= self.pos:
+            self.reset()
+        v = None
+        while self.pos < idx:
+            # NOTE: don't keep the last item in self, it can be expensive
+            v = next(self.iterator)
+            self.pos += 1
+        return v
+
+    def reset(self):
+        self.iterator = iter(self.iterable)
+        self.pos = -1
+
 class FrameProvider:
     class Quality(Enum):
         COMPRESSED = 0
@@ -35,7 +61,8 @@ def __init__(self, reader_class, path_getter):
         def load(self, chunk_id):
             if self.chunk_id != chunk_id:
                 self.chunk_id = chunk_id
-                self.chunk_reader = self.reader_class([self.get_chunk_path(chunk_id)])
+                self.chunk_reader = RandomAccessIterator(
+                    self.reader_class([self.get_chunk_path(chunk_id)]))
             return self.chunk_reader
 
     def __init__(self, db_data):
@@ -104,18 +131,18 @@ def get_chunk(self, chunk_number, quality=Quality.ORIGINAL):
         chunk_number = self._validate_chunk_number(chunk_number)
         return self._loaders[quality].get_chunk_path(chunk_number)
 
-    def get_frame(self, frame_number, quality=Quality.ORIGINAL):
+    def get_frame(self, frame_number, quality=Quality.ORIGINAL,
+            out_type=Type.BUFFER):
         _, chunk_number, frame_offset = self._validate_frame_number(frame_number)
+        loader = self._loaders[quality]
+        chunk_reader = loader.load(chunk_number)
+        frame, frame_name, _ = chunk_reader[frame_offset]
 
-        chunk_reader = self._loaders[quality].load(chunk_number)
-
-        frame, frame_name, _ = next(itertools.islice(chunk_reader, frame_offset, None))
-        if self._loaders[quality].reader_class is VideoReader:
-            return (self._av_frame_to_png_bytes(frame), 'image/png')
+        frame = self._convert_frame(frame, loader.reader_class, out_type)
+        if loader.reader_class is VideoReader:
+            return (frame, 'image/png')
         return (frame, mimetypes.guess_type(frame_name))
 
     def get_frames(self, quality=Quality.ORIGINAL, out_type=Type.BUFFER):
-        loader = self._loaders[quality]
-        for chunk_idx in range(math.ceil(self._db_data.size / self._db_data.chunk_size)):
-            for frame, _, _ in loader.load(chunk_idx):
-                yield self._convert_frame(frame, loader.reader_class, out_type)
+        for idx in range(self._db_data.size):
+            yield self.get_frame(idx, quality=quality, out_type=out_type)