Skip to content

Commit

Permalink
Fixed the memory consumption when creating a task (#2582)
Browse files Browse the repository at this point in the history
* fixed the memory consumption when creating a task

* fixed _get_frame_size function

* updated changelog
  • Loading branch information
Andrey Zhavoronkov authored Dec 16, 2020
1 parent eb349a6 commit f8e9dc3
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 21 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Memory consumption for the task creation process (<https://github.com/openvinotoolkit/cvat/pull/2582>)

### Security

Expand Down
54 changes: 34 additions & 20 deletions cvat/apps/engine/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,35 @@ def __init__(self, **kwargs):
raise Exception('No sourse path')
self.source_path = kwargs.get('source_path')

def _open_video_container(self, sourse_path, mode, options=None):
@staticmethod
def _open_video_container(sourse_path, mode, options=None):
return av.open(sourse_path, mode=mode, options=options)

def _close_video_container(self, container):
@staticmethod
def _close_video_container(container):
container.close()

def _get_video_stream(self, container):
@staticmethod
def _get_video_stream(container):
video_stream = next(stream for stream in container.streams if stream.type == 'video')
video_stream.thread_type = 'AUTO'
return video_stream

@staticmethod
def _get_frame_size(container):
    """Return ``(width, height)`` of the first decodable video frame.

    Only the first frame is decoded. If the video stream carries ``rotate``
    metadata, that frame is rotated with ``rotate_image`` by
    ``360 - rotate`` degrees and the size of the rotated frame is returned.

    Args:
        container: An opened container exposing ``streams`` and ``demux``
            (presumably the object returned by ``_open_video_container``).

    Returns:
        A ``(width, height)`` tuple.

    Raises:
        ValueError: If the video stream yields no decodable frame. Without
            this, the function would return ``None`` and callers that
            unpack the result would fail with an unrelated ``TypeError``.
    """
    video_stream = WorkWithVideo._get_video_stream(container)
    # Read the rotation once, from the same stream that is being demuxed,
    # instead of mixing ``video_stream`` and ``container.streams.video[0]``.
    rotation = video_stream.metadata.get('rotate')

    for packet in container.demux(video_stream):
        for frame in packet.decode():
            if rotation:
                # ``from_ndarray`` is a static constructor; no need to
                # instantiate an empty VideoFrame first.
                frame = av.VideoFrame.from_ndarray(
                    rotate_image(
                        frame.to_ndarray(format='bgr24'),
                        360 - int(rotation),
                    ),
                    format='bgr24',
                )
            # The first decoded frame is enough to know the frame size.
            return frame.width, frame.height

    raise ValueError('Cannot determine the frame size: no video frames were decoded')

class AnalyzeVideo(WorkWithVideo):
def check_type_first_frame(self):
container = self._open_video_container(self.source_path, mode='r')
Expand Down Expand Up @@ -71,28 +89,21 @@ def __init__(self, **kwargs):
self.key_frames = {}
self.frames = 0

container = self._open_video_container(self.source_path, 'r')
self.width, self.height = self._get_frame_size(container)
self._close_video_container(container)

def get_task_size(self):
    """Return the number of frames recorded on this instance (``self.frames``)."""
    frame_count = self.frames
    return frame_count

@property
def frame_sizes(self):
container = self._open_video_container(self.source_path, 'r')
frame = next(iter(self.key_frames.values()))
if container.streams.video[0].metadata.get('rotate'):
frame = av.VideoFrame().from_ndarray(
rotate_image(
frame.to_ndarray(format='bgr24'),
360 - int(container.streams.video[0].metadata.get('rotate'))
),
format ='bgr24'
)
self._close_video_container(container)
return (frame.width, frame.height)
return (self.width, self.height)

def check_key_frame(self, container, video_stream, key_frame):
for packet in container.demux(video_stream):
for frame in packet.decode():
if md5_hash(frame) != md5_hash(key_frame[1]) or frame.pts != key_frame[1].pts:
if md5_hash(frame) != key_frame[1]['md5'] or frame.pts != key_frame[1]['pts']:
self.key_frames.pop(key_frame[0])
return

Expand All @@ -103,7 +114,7 @@ def check_seek_key_frames(self):
key_frames_copy = self.key_frames.copy()

for key_frame in key_frames_copy.items():
container.seek(offset=key_frame[1].pts, stream=video_stream)
container.seek(offset=key_frame[1]['pts'], stream=video_stream)
self.check_key_frame(container, video_stream, key_frame)

def check_frames_ratio(self, chunk_size):
Expand All @@ -114,10 +125,13 @@ def save_key_frames(self):
video_stream = self._get_video_stream(container)
frame_number = 0

for packet in container.demux(video_stream):
for packet in container.demux(video_stream):
for frame in packet.decode():
if frame.key_frame:
self.key_frames[frame_number] = frame
self.key_frames[frame_number] = {
'pts': frame.pts,
'md5': md5_hash(frame),
}
frame_number += 1

self.frames = frame_number
Expand All @@ -126,7 +140,7 @@ def save_key_frames(self):
def save_meta_info(self):
with open(self.meta_path, 'w') as meta_file:
for index, frame in self.key_frames.items():
meta_file.write('{} {}\n'.format(index, frame.pts))
meta_file.write('{} {}\n'.format(index, frame['pts']))

def get_nearest_left_key_frame(self, start_chunk_frame_number):
start_decode_frame_number = 0
Expand Down

0 comments on commit f8e9dc3

Please sign in to comment.