Skip to content

Commit

Permalink
Add chunk iterator cache to frame provider (#1367)
Browse files Browse the repository at this point in the history
* Add chunk iterator cache

* fix
  • Loading branch information
zhiltsov-max authored Apr 8, 2020
1 parent afeab69 commit 557e308
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 13 deletions.
2 changes: 1 addition & 1 deletion cvat/apps/dataset_manager/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __iter__(self):
frames = self._frame_provider.get_frames(
self._frame_provider.Quality.ORIGINAL,
self._frame_provider.Type.NUMPY_ARRAY)
for item_id, image in enumerate(frames):
for item_id, (image, _) in enumerate(frames):
yield datumaro.DatasetItem(
id=item_id,
image=Image(image),
Expand Down
51 changes: 39 additions & 12 deletions cvat/apps/engine/frame_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# SPDX-License-Identifier: MIT

import itertools
import math
from enum import Enum
from io import BytesIO
Expand All @@ -15,6 +14,33 @@
from cvat.apps.engine.models import DataChoice


class RandomAccessIterator:
    """Give index-based access to an iterable that can only be read forward.

    The wrapped ``iterable`` is consumed lazily through ``iter(iterable)``.
    Asking for an index beyond the current position advances the underlying
    iterator; asking for the current or an earlier index restarts iteration
    from the beginning and advances again. No item is cached, so repeating a
    request for the same index re-reads the source from the start.

    ``iterable`` has to produce a fresh iterator on every ``iter()`` call.
    A one-shot iterator returns itself from ``iter()``, so it cannot be
    rewound by :meth:`reset`.
    """

    def __init__(self, iterable):
        self.iterable = iterable
        # Created lazily on first access, so construction never touches the
        # underlying source.
        self.iterator = None
        # Index of the most recently produced item; -1 means "none yet".
        self.pos = -1

    def __iter__(self):
        # The object is its own iterator: iteration continues from the
        # current position, it does not restart.
        return self

    def __next__(self):
        """Return the item that follows the current position."""
        return self[self.pos + 1]

    def __getitem__(self, idx):
        """Return the item at position ``idx`` (0-based).

        Raises:
            IndexError: if ``idx`` is negative.
            StopIteration: if the source holds fewer than ``idx + 1`` items
                (propagated from the underlying iterator, which also makes
                ``__next__`` terminate iteration correctly).
        """
        # An explicit check instead of ``assert``: assertions are removed
        # under ``python -O``, and a negative index would then fall through
        # and silently return None.
        if idx < 0:
            raise IndexError(
                "index must be non-negative, got {}".format(idx))

        # Going backwards (or repeating the current index) is only possible
        # by starting over.
        if self.iterator is None or idx <= self.pos:
            self.reset()

        v = None
        while self.pos < idx:
            # NOTE: don't keep the last item in self, it can be expensive
            v = next(self.iterator)
            self.pos += 1
        return v

    def reset(self):
        """Restart iteration from the beginning of the wrapped iterable."""
        self.iterator = iter(self.iterable)
        self.pos = -1

class FrameProvider:
class Quality(Enum):
COMPRESSED = 0
Expand All @@ -35,7 +61,8 @@ def __init__(self, reader_class, path_getter):
def load(self, chunk_id):
if self.chunk_id != chunk_id:
self.chunk_id = chunk_id
self.chunk_reader = self.reader_class([self.get_chunk_path(chunk_id)])
self.chunk_reader = RandomAccessIterator(
self.reader_class([self.get_chunk_path(chunk_id)]))
return self.chunk_reader

def __init__(self, db_data):
Expand Down Expand Up @@ -104,18 +131,18 @@ def get_chunk(self, chunk_number, quality=Quality.ORIGINAL):
chunk_number = self._validate_chunk_number(chunk_number)
return self._loaders[quality].get_chunk_path(chunk_number)

def get_frame(self, frame_number, quality=Quality.ORIGINAL):
def get_frame(self, frame_number, quality=Quality.ORIGINAL,
out_type=Type.BUFFER):
_, chunk_number, frame_offset = self._validate_frame_number(frame_number)
loader = self._loaders[quality]
chunk_reader = loader.load(chunk_number)
frame, frame_name, _ = chunk_reader[frame_offset]

chunk_reader = self._loaders[quality].load(chunk_number)

frame, frame_name, _ = next(itertools.islice(chunk_reader, frame_offset, None))
if self._loaders[quality].reader_class is VideoReader:
return (self._av_frame_to_png_bytes(frame), 'image/png')
frame = self._convert_frame(frame, loader.reader_class, out_type)
if loader.reader_class is VideoReader:
return (frame, 'image/png')
return (frame, mimetypes.guess_type(frame_name))

def get_frames(self, quality=Quality.ORIGINAL, out_type=Type.BUFFER):
loader = self._loaders[quality]
for chunk_idx in range(math.ceil(self._db_data.size / self._db_data.chunk_size)):
for frame, _, _ in loader.load(chunk_idx):
yield self._convert_frame(frame, loader.reader_class, out_type)
for idx in range(self._db_data.size):
yield self.get_frame(idx, quality=quality, out_type=out_type)

0 comments on commit 557e308

Please sign in to comment.