[Feature] Support visualization for single gpu test (#216)
* add single gpu mot visualization

* add mot inference visualization

* fix some typos

* fix some bugs

* support sot and vid visualization

* fix a bug

* fix some bugs

* fix sot and vid

* delete dummy code

* refactor video writer

* fix some bugs

* fix some bugs

* fix a bug
ToumaKazusa3 authored Aug 6, 2021
1 parent e5a0c1e commit cac7080
Showing 6 changed files with 96 additions and 21 deletions.
1 change: 0 additions & 1 deletion demo/demo_mot.py
@@ -67,7 +67,6 @@ def main():
if isinstance(img, str):
img = osp.join(args.input, img)
result = inference_mot(model, img, frame_id=i)
result = result['track_results']
if args.output is not None:
if IN_VIDEO or OUT_VIDEO:
out_file = osp.join(out_path, f'{i:06d}.jpg')
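With this change the MOT demo no longer unpacks result['track_results'] itself; the whole dict returned by inference_mot is handed to model.show_result, which does the unpacking (see the mmtrack/models/mot/base.py change below). A minimal sketch of the resulting demo loop, assuming placeholder config, checkpoint, and image paths that are not part of this commit:

import os
import os.path as osp

from mmtrack.apis import inference_mot, init_model

# Placeholder paths: substitute a real MOT config/checkpoint and image folder.
model = init_model('configs/mot/some_mot_config.py', 'some_checkpoint.pth', device='cuda:0')
img_dir = 'data/some_video/img'
for i, name in enumerate(sorted(os.listdir(img_dir))):
    img = osp.join(img_dir, name)
    result = inference_mot(model, img, frame_id=i)  # result is a dict
    # Pass the whole dict through; no result['track_results'] indexing here.
    model.show_result(img, result, show=False, out_file=f'outputs/{i:06d}.jpg')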
3 changes: 1 addition & 2 deletions demo/demo_sot.py
@@ -53,10 +53,9 @@ def main():
# test a single image
result = inference_sot(model, frame, init_bbox, frame_id)

track_bbox = result['track_results'][:4]
vis_frame = model.show_result(
frame,
track_bbox,
result,
color=args.color,
thickness=args.thickness,
show=False)
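The SOT demo mirrors the same interface change: instead of slicing result['track_results'][:4] in the demo, the full result dict goes to model.show_result, which drops the score internally (see the mmtrack/models/sot/base.py change below). A hedged sketch with placeholder paths and synthetic inputs:

import numpy as np

from mmtrack.apis import inference_sot, init_model

# Placeholder config/checkpoint; the synthetic frame stands in for a decoded video frame.
model = init_model('configs/sot/some_sot_config.py', 'some_checkpoint.pth', device='cuda:0')
frame = np.zeros((360, 640, 3), dtype=np.uint8)
init_bbox = [100, 100, 200, 200]  # [tl_x, tl_y, br_x, br_y] of the target in frame 0

result = inference_sot(model, frame, init_bbox, frame_id=0)
# result['track_results'] is a length-5 array: [tl_x, tl_y, br_x, br_y, score].
vis_frame = model.show_result(frame, result, color='green', thickness=3, show=False)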
76 changes: 68 additions & 8 deletions mmtrack/apis/test.py
@@ -1,3 +1,4 @@
import os
import os.path as osp
import shutil
import tempfile
@@ -7,45 +8,104 @@
import mmcv
import torch
import torch.distributed as dist
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info


def single_gpu_test(model,
data_loader,
show=False,
out_dir=None,
fps=3,
show_score_thr=0.3):
"""Test model with single gpu.
Args:
model (nn.Module): Model to be tested.
data_loader (nn.Dataloader): Pytorch data loader.
show (bool): If True, visualize the prediction results (Not supported
for now). Defaults to False.
out_dir (str): Path of directory to save the visualization results (Not
supported for now). Defaults to None.
show_score_thr (float): The score threthold of visualization (Not
supported for now). Defaults to 0.3.
show (bool, optional): If True, visualize the prediction results.
Defaults to False.
out_dir (str, optional): Path of directory to save the
visualization results. Defaults to None.
fps (int, optional): FPS of the output video.
Defaults to 3.
show_score_thr (float, optional): The score threshold of visualization
(Only used in VID for now). Defaults to 0.3.
Returns:
dict[str, list]: The prediction results.
"""
model.eval()
results = defaultdict(list)
dataset = data_loader.dataset
prev_img_meta = None
prog_bar = mmcv.ProgressBar(len(dataset))
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
for k, v in result.items():
results[k].append(v)

batch_size = data['img'][0].size(0)
if show or out_dir:
pass # TODO
assert batch_size == 1, 'Only support batch_size=1 when testing.'
img_tensor = data['img'][0]
img_meta = data['img_metas'][0].data[0][0]
img = tensor2imgs(img_tensor, **img_meta['img_norm_cfg'])[0]

h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]

ori_h, ori_w = img_meta['ori_shape'][:-1]
img_show = mmcv.imresize(img_show, (ori_w, ori_h))

if out_dir:
out_file = osp.join(out_dir, img_meta['ori_filename'])
else:
out_file = None

model.module.show_result(
img_show,
result,
show=show,
out_file=out_file,
score_thr=show_score_thr)

# Whether we need to generate a video from images.
# The frame_id == 0 means the model starts processing
# a new video, therefore we can write the previous video.
# There are two corner cases.
# Case 1: prev_img_meta == None means there is no previous video.
# Case 2: i == len(dataset) means processing the last video
need_write_video = (
prev_img_meta is not None and img_meta['frame_id'] == 0
or i == len(dataset))
if out_dir and need_write_video:
prev_img_prefix, prev_img_name = prev_img_meta[
'ori_filename'].rsplit('/', 1)
prev_img_idx, prev_img_type = prev_img_name.split('.')
prev_filename_tmpl = '{:0' + str(
len(prev_img_idx)) + 'd}.' + prev_img_type
prev_img_dirs = f'{out_dir}/{prev_img_prefix}'
prev_img_names = sorted(os.listdir(prev_img_dirs))
prev_start_frame_id = int(prev_img_names[0].split('.')[0])
prev_end_frame_id = int(prev_img_names[-1].split('.')[0])

mmcv.frames2video(
prev_img_dirs,
f'{prev_img_dirs}/out_video.mp4',
fps=fps,
fourcc='mp4v',
filename_tmpl=prev_filename_tmpl,
start=prev_start_frame_id,
end=prev_end_frame_id,
show_progress=False)

prev_img_meta = img_meta

batch_size = data['img'][0].size(0)
for _ in range(batch_size):
prog_bar.update()

return results


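Putting the new arguments together, a single-GPU test with visualization can be driven roughly as below (a sketch assuming a standard MMTracking 0.x config; paths are placeholders). When out_dir is set, every frame is saved under out_dir following its original filename, and once a sequence finishes its frames are stitched into an out_video.mp4 at the requested fps via mmcv.frames2video:

from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint

from mmtrack.apis import single_gpu_test
from mmtrack.datasets import build_dataloader, build_dataset
from mmtrack.models import build_model

cfg = Config.fromfile('configs/mot/some_mot_config.py')  # placeholder config
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
    dataset, samples_per_gpu=1, workers_per_gpu=2, dist=False, shuffle=False)

model = build_model(cfg.model)
load_checkpoint(model, 'some_checkpoint.pth', map_location='cpu')  # placeholder weights
model = MMDataParallel(model, device_ids=[0])

# show=False keeps the run headless; out_dir enables per-frame images and per-video mp4s.
results = single_gpu_test(
    model, data_loader, show=False, out_dir='vis_results', fps=10)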
13 changes: 10 additions & 3 deletions mmtrack/models/mot/base.py
@@ -247,12 +247,17 @@ def show_result(self,
show=False,
out_file=None,
wait_time=0,
backend='cv2'):
backend='cv2',
**kwargs):
"""Visualize tracking results.
Args:
img (str | ndarray): Filename of loaded image.
result (list[ndarray]): Tracking results.
result (dict): Tracking result.
The value of key 'track_results' is ndarray with shape (n, 6)
in [id, tl_x, tl_y, br_x, br_y, score] format.
The value of key 'bbox_results' is ndarray with shape (n, 5)
in [tl_x, tl_y, br_x, br_y, score] format.
thickness (int, optional): Thickness of lines. Defaults to 1.
font_scale (float, optional): Font scales of texts. Defaults
to 0.5.
@@ -265,7 +270,9 @@ def show_result(self,
Returns:
ndarray: Visualized image.
"""
bboxes, labels, ids = restore_result(result, return_ids=True)
assert isinstance(result, dict)
track_result = result.get('track_results', None)
bboxes, labels, ids = restore_result(track_result, return_ids=True)
img = imshow_tracks(
img,
bboxes,
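The dict-based contract can also be exercised directly with synthetic data. In the sketch below the track array is wrapped in a per-class list, matching how restore_result consumes results elsewhere in the codebase; the placeholder config path and the exact wrapping are assumptions for illustration:

import numpy as np

from mmtrack.apis import init_model

model = init_model('configs/mot/some_mot_config.py', device='cuda:0')  # placeholder config

# One class; rows are [id, tl_x, tl_y, br_x, br_y, score].
fake_result = dict(track_results=[
    np.array([[1, 50., 60., 150., 200., 0.90],
              [2, 30., 40., 90., 120., 0.75]], dtype=np.float32)
])
img = np.zeros((240, 320, 3), dtype=np.uint8)
vis = model.show_result(img, fake_result, thickness=2, show=False)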
17 changes: 12 additions & 5 deletions mmtrack/models/sot/base.py
@@ -267,12 +267,15 @@ def show_result(self,
show=False,
win_name='',
wait_time=0,
out_file=None):
out_file=None,
**kwargs):
"""Visualize tracking results.
Args:
img (str or ndarray): The image to be displayed.
result (ndarray): ndarray of shape (4, ).
result (dict): Tracking result.
The value of key 'track_results' is ndarray with shape (5, )
in [tl_x, tl_y, br_x, br_y, score] format.
color (str or tuple or Color, optional): color of bbox.
Defaults to green.
thickness (int, optional): Thickness of lines.
@@ -289,11 +292,15 @@
Returns:
ndarray: Visualized image.
"""
assert result.ndim == 1
assert result.shape[0] == 4
assert isinstance(result, dict)
track_result = result.get('track_results', None)
assert track_result.ndim == 1
assert track_result.shape[0] == 5

track_bbox = track_result[:4]
mmcv.imshow_bboxes(
img,
result[np.newaxis, :],
track_bbox[np.newaxis, :],
colors=color,
thickness=thickness,
show=show,
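The SOT variant now asserts the same layout that inference_sot produces: a single 5-element 'track_results' array whose trailing score is dropped before drawing. A short synthetic example under the same placeholder-config assumption:

import numpy as np

from mmtrack.apis import init_model

model = init_model('configs/sot/some_sot_config.py', device='cuda:0')  # placeholder config
fake_result = dict(
    track_results=np.array([50., 60., 150., 200., 0.87], dtype=np.float32))
img = np.zeros((240, 320, 3), dtype=np.uint8)
vis = model.show_result(img, fake_result, color='green', thickness=2, show=False)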
7 changes: 5 additions & 2 deletions mmtrack/models/vid/base.py
@@ -315,8 +315,11 @@ def show_result(self,
Args:
img (str or Tensor): The image to be displayed.
result (Tensor or tuple): The results to draw over `img`
bbox_result or (bbox_result, segm_result).
result (dict): The results to draw over `img` bbox_result or
(bbox_result, segm_result). The value of key 'bbox_results'
is list with length num_classes, and each element in list
is ndarray with shape(n, 5)
in [tl_x, tl_y, br_x, br_y, score] format.
score_thr (float, optional): Minimum score of bboxes to be shown.
Default: 0.3.
bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
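For VID, show_result likewise takes the result dict; 'bbox_results' holds one (n, 5) array per class, the mmdet-style detection layout. A synthetic two-class example, again with a placeholder config path:

import numpy as np

from mmtrack.apis import init_model

model = init_model('configs/vid/some_vid_config.py', device='cuda:0')  # placeholder config

# Class 0 has one detection, class 1 has none.
fake_result = dict(bbox_results=[
    np.array([[10., 10., 60., 80., 0.95]], dtype=np.float32),
    np.zeros((0, 5), dtype=np.float32),
])
img = np.zeros((240, 320, 3), dtype=np.uint8)
vis = model.show_result(img, fake_result, score_thr=0.3, show=False)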
