[Improvement] Test with onnx models and TensorRT engines. (#758)
* tensorrt inference first commit

* support test with onnx model

* update docs

* update changelog

* update docs

* update changelog

* update

* update changelog
irvingzhang0512 authored Apr 7, 2021
1 parent a2cbd11 commit 6a252b8
Showing 4 changed files with 182 additions and 42 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
@@ -9,6 +9,7 @@
**Improvements**

- Add softmax option for pytorch2onnx tool ([#781](https://github.com/open-mmlab/mmaction2/pull/781))
- Test with onnx models and TensorRT engines ([#758](https://github.com/open-mmlab/mmaction2/pull/758))

**Bug and Typo Fixes**

12 changes: 11 additions & 1 deletion docs/getting_started.md
@@ -68,7 +68,7 @@ You can use the following commands to test a dataset.
# single-gpu testing
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] \
[--gpu-collect] [--tmpdir ${TMPDIR}] [--options ${OPTIONS}] [--average-clips ${AVG_TYPE}] \
[--launcher ${JOB_LAUNCHER}] [--local_rank ${LOCAL_RANK}]
[--launcher ${JOB_LAUNCHER}] [--local_rank ${LOCAL_RANK}] [--onnx] [--tensorrt]

# multi-gpu testing
./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] \
@@ -86,6 +86,8 @@ Optional arguments:
- `AVG_TYPE`: How to average the test clips. If set to `prob`, softmax is applied to the clip scores before averaging; otherwise, the clip scores are averaged directly.
- `JOB_LAUNCHER`: Launcher for distributed job initialization. Allowed choices are `none`, `pytorch`, `slurm`, `mpi`. In particular, if set to `none`, testing runs in non-distributed mode.
- `LOCAL_RANK`: ID for local rank. If not specified, it will be set to 0.
- `--onnx`: If specified, recognition results will be generated by the onnx model, and `CHECKPOINT_FILE` should be the path of an onnx model file. Onnx model files are generated by `tools/pytorch2onnx.py`. For now, multi-gpu mode and dynamic input shape mode are not supported. Please note that the output tensors of the dataset and the input tensors of the onnx model should share the same shape. It is also recommended to remove all test-time augmentation methods from `test_pipeline` (`ThreeCrop`, `TenCrop`, `twice_sample`, etc.).
- `--tensorrt`: If specified, recognition results will be generated by the TensorRT engine, and `CHECKPOINT_FILE` should be the path of a TensorRT engine file. TensorRT engines are built from exported onnx models with TensorRT's official conversion tools (see the sketch below). For now, multi-gpu mode and dynamic input shape mode are not supported. Please note that the output tensors of the dataset and the input tensors of the TensorRT engine should share the same shape. It is also recommended to remove all test-time augmentation methods from `test_pipeline` (`ThreeCrop`, `TenCrop`, `twice_sample`, etc.).
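
For reference, the sketch below shows one way to produce the model files these two options consume. The `tools/pytorch2onnx.py` flags and the output paths are illustrative rather than exact (check the tool's `--help`); `trtexec` is the conversion tool shipped with TensorRT.

```shell
# export the pytorch checkpoint to onnx; the input shape is fixed and
# should match the shape of the tensors produced by test_pipeline
python tools/pytorch2onnx.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \
    --output-file checkpoints/SOME_CHECKPOINT.onnx

# build a TensorRT engine from the exported onnx model
trtexec --onnx=checkpoints/SOME_CHECKPOINT.onnx \
    --saveEngine=checkpoints/SOME_CHECKPOINT.trt
```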

Examples:

@@ -115,6 +117,14 @@ Assume that you have already downloaded the checkpoints to the directory `checkpoints/`.
--launcher slurm --eval top_k_accuracy
```

4. Test TSN on Something-Something V1 with an onnx model and evaluate the top-k accuracy

```shell
python tools/test.py configs/recognition/tsn/tsn_r50_1x1x3_100e_kinetics400_rgb.py \
checkpoints/SOME_CHECKPOINT.onnx \
--eval top_k_accuracy --onnx
```
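
5. Test TSN on Something-Something V1 with a TensorRT engine and evaluate the top-k accuracy (the engine file below is a placeholder for one you have built yourself, e.g. with `trtexec`)

    ```shell
    python tools/test.py configs/recognition/tsn/tsn_r50_1x1x3_100e_kinetics400_rgb.py \
        checkpoints/SOME_CHECKPOINT.trt \
        --eval top_k_accuracy --tensorrt
    ```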

### High-level APIs for testing a video and rawframes

Here is an example of building the model and testing a given video.
14 changes: 12 additions & 2 deletions docs_zh_CN/getting_started.md
@@ -17,7 +17,7 @@
- [Training with a single GPU](#使用单个-GPU-进行训练)
- [Training with multiple GPUs](#使用多个-GPU-进行训练)
- [Training with multiple machines](#使用多台机器进行训练)
- [Creating multiple jobs on a single machine](#使用单台机器启动多个任务)
- [Launching multiple jobs on a single machine](#使用单台机器启动多个任务)
- [Detailed tutorials](#详细教程)

<!-- TOC -->
@@ -67,7 +67,7 @@ MMAction2 provides some scripts for testing datasets (such as Kinetics-400, Something-Something, etc.)
# single-GPU testing
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] \
[--gpu-collect] [--tmpdir ${TMPDIR}] [--options ${OPTIONS}] [--average-clips ${AVG_TYPE}] \
[--launcher ${JOB_LAUNCHER}] [--local_rank ${LOCAL_RANK}]
[--launcher ${JOB_LAUNCHER}] [--local_rank ${LOCAL_RANK}] [--onnx] [--tensorrt]

# multi-GPU testing
./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] \
@@ -85,6 +85,8 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] \
- `AVG_TYPE`: How to average the test-clip results. If set to `prob`, a softmax is applied to the clip scores before averaging; otherwise, the clip scores are averaged directly.
- `JOB_LAUNCHER`: Launcher for distributed job initialization. Allowed choices are `none`, `pytorch`, `slurm`, `mpi`. In particular, if set to `none`, testing runs in non-distributed mode.
- `LOCAL_RANK`: ID of the local rank. If not specified, it is set to 0.
- `--onnx`: If specified, predictions are obtained by onnx-model inference, and the `CHECKPOINT_FILE` argument should be an onnx model file. Onnx model files are exported by the `tools/pytorch2onnx.py` script. For now, multi-GPU testing and dynamic tensor shapes (dynamic shape) are not supported. Please note that the dataset's output tensors and the model's input tensors should share the same shape. It is also not recommended to use test-time augmentations such as `ThreeCrop`, `TenCrop`, and `twice_sample`.
- `--tensorrt`: If specified, predictions are obtained by TensorRT-engine inference, and the `CHECKPOINT_FILE` argument should be a TensorRT engine file. TensorRT engine files are generated from exported onnx models with TensorRT's official conversion tools (see the sketch below). For now, multi-GPU testing and dynamic tensor shapes (dynamic shape) are not supported. Please note that the dataset's output tensors and the engine's input tensors should share the same shape. It is also not recommended to use test-time augmentations such as `ThreeCrop`, `TenCrop`, and `twice_sample`.
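
The sketch below shows one way to produce these model files, assuming a fixed input shape. The flags and paths are illustrative rather than exact (check each tool's `--help`); `trtexec` is the conversion tool shipped with TensorRT.

```shell
# export the pytorch checkpoint to onnx; the input shape is fixed and
# should match the shape of the tensors produced by test_pipeline
python tools/pytorch2onnx.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \
    --output-file checkpoints/SOME_CHECKPOINT.onnx

# build a TensorRT engine from the exported onnx model
trtexec --onnx=checkpoints/SOME_CHECKPOINT.onnx \
    --saveEngine=checkpoints/SOME_CHECKPOINT.trt
```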

Examples:

@@ -114,6 +116,14 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] \
--launcher slurm --eval top_k_accuracy
```

4. Test a TSN model in onnx format on Something-Something V1 and evaluate the `top-k accuracy` metric

```shell
python tools/test.py configs/recognition/tsn/tsn_r50_1x1x3_100e_kinetics400_rgb.py \
checkpoints/SOME_CHECKPOINT.onnx \
--eval top_k_accuracy --onnx
```
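
5. Test TSN on Something-Something V1 with a TensorRT engine and evaluate the `top-k accuracy` metric (the engine file below is a placeholder for one you have built yourself, e.g. with `trtexec`)

    ```shell
    python tools/test.py configs/recognition/tsn/tsn_r50_1x1x3_100e_kinetics400_rgb.py \
        checkpoints/SOME_CHECKPOINT.trt \
        --eval top_k_accuracy --tensorrt
    ```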

### High-level APIs for testing videos and rawframe folders

Here is an example of building the model and testing a given video.
197 changes: 158 additions & 39 deletions tools/test.py
@@ -80,6 +80,14 @@ def parse_args():
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--onnx',
action='store_true',
help='Whether to test with onnx model or not')
parser.add_argument(
'--tensorrt',
action='store_true',
help='Whether to test with TensorRT engine or not')
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
@@ -106,9 +114,152 @@ def turn_off_pretrained(cfg):
turn_off_pretrained(sub_cfg)


def inference_pytorch(args, cfg, distributed, data_loader):
"""Get predictions by pytorch models."""
if args.average_clips is not None:
# You can set average_clips during testing; it will override the
# original setting
if cfg.model.get('test_cfg') is None and cfg.get('test_cfg') is None:
cfg.model.setdefault('test_cfg',
dict(average_clips=args.average_clips))
else:
if cfg.model.get('test_cfg') is not None:
cfg.model.test_cfg.average_clips = args.average_clips
else:
cfg.test_cfg.average_clips = args.average_clips

# remove redundant pretrain steps for testing
turn_off_pretrained(cfg.model)

# build the model and load checkpoint
model = build_model(
cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))

if len(cfg.module_hooks) > 0:
register_module_hooks(model, cfg.module_hooks)

fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)
load_checkpoint(model, args.checkpoint, map_location='cpu')

if args.fuse_conv_bn:
model = fuse_conv_bn(model)

if not distributed:
model = MMDataParallel(model, device_ids=[0])
outputs = single_gpu_test(model, data_loader)
else:
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False)
outputs = multi_gpu_test(model, data_loader, args.tmpdir,
args.gpu_collect)

return outputs


def inference_tensorrt(ckpt_path, distributed, data_loader, batch_size):
"""Get predictions by TensorRT engine.
For now, multi-gpu mode and dynamic tensor shape are not supported.
"""
assert not distributed, \
'TensorRT engine inference only supports single gpu mode.'
import tensorrt as trt
from mmcv.tensorrt.tensorrt_utils import (torch_dtype_from_trt,
torch_device_from_trt)

# load engine
with trt.Logger() as logger, trt.Runtime(logger) as runtime:
with open(ckpt_path, mode='rb') as f:
engine_bytes = f.read()
engine = runtime.deserialize_cuda_engine(engine_bytes)

# For now, only support fixed input tensor
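    # Bindings are assumed fixed: index 0 is the engine's only input and
    # index 1 its only output.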
cur_batch_size = engine.get_binding_shape(0)[0]
assert batch_size == cur_batch_size, \
('Dataset and TensorRT model should share the same batch size, '
         f'but got {batch_size} and {cur_batch_size}')

context = engine.create_execution_context()

# get output tensor
dtype = torch_dtype_from_trt(engine.get_binding_dtype(1))
shape = tuple(context.get_binding_shape(1))
device = torch_device_from_trt(engine.get_location(1))
output = torch.empty(
size=shape, dtype=dtype, device=device, requires_grad=False)

# get predictions
results = []
dataset = data_loader.dataset
prog_bar = mmcv.ProgressBar(len(dataset))
for data in data_loader:
bindings = [
data['imgs'].contiguous().data_ptr(),
output.contiguous().data_ptr()
]
context.execute_async_v2(bindings,
torch.cuda.current_stream().cuda_stream)
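        # output.cpu() below issues a blocking device-to-host copy on the
        # same CUDA stream, so it implicitly waits for the async execution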
results.extend(output.cpu().numpy())
batch_size = len(next(iter(data.values())))
for _ in range(batch_size):
prog_bar.update()
return results


def inference_onnx(ckpt_path, distributed, data_loader, batch_size):
"""Get predictions by ONNX.
For now, multi-gpu mode and dynamic tensor shape are not supported.
"""
assert not distributed, 'ONNX inference only supports single gpu mode.'

import onnx
import onnxruntime as rt

# get input tensor name
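    # initializers (weights) may also appear in graph.input in exported
    # graphs, so the real feed input is the input set minus the initializers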
onnx_model = onnx.load(ckpt_path)
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [node.name for node in onnx_model.graph.initializer]
net_feed_input = list(set(input_all) - set(input_initializer))
assert len(net_feed_input) == 1

# For now, only support fixed tensor shape
input_tensor = None
for tensor in onnx_model.graph.input:
if tensor.name == net_feed_input[0]:
input_tensor = tensor
break
cur_batch_size = input_tensor.type.tensor_type.shape.dim[0].dim_value
assert batch_size == cur_batch_size, \
('Dataset and ONNX model should share the same batch size, '
         f'but got {batch_size} and {cur_batch_size}')

# get predictions
sess = rt.InferenceSession(ckpt_path)
results = []
dataset = data_loader.dataset
prog_bar = mmcv.ProgressBar(len(dataset))
for data in data_loader:
imgs = data['imgs'].cpu().numpy()
onnx_result = sess.run(None, {net_feed_input[0]: imgs})[0]
results.extend(onnx_result)
batch_size = len(next(iter(data.values())))
for _ in range(batch_size):
prog_bar.update()
return results


def main():
args = parse_args()

if args.tensorrt and args.onnx:
raise ValueError(
'Cannot set onnx mode and tensorrt mode at the same time.')

cfg = Config.fromfile(args.config)

cfg.merge_from_dict(args.cfg_options)
@@ -158,18 +309,6 @@ def main():
torch.backends.cudnn.benchmark = True
cfg.data.test.test_mode = True

if args.average_clips is not None:
# You can set average_clips during testing, it will override the
# original setting
if cfg.model.get('test_cfg') is None and cfg.get('test_cfg') is None:
cfg.model.setdefault('test_cfg',
dict(average_clips=args.average_clips))
else:
if cfg.model.get('test_cfg') is not None:
cfg.model.test_cfg.average_clips = args.average_clips
else:
cfg.test_cfg.average_clips = args.average_clips

# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
@@ -191,34 +330,14 @@ def main():
**cfg.data.get('test_dataloader', {}))
data_loader = build_dataloader(dataset, **dataloader_setting)

# remove redundant pretrain steps for testing
turn_off_pretrained(cfg.model)

# build the model and load checkpoint
model = build_model(
cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))

if len(cfg.module_hooks) > 0:
register_module_hooks(model, cfg.module_hooks)

fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)
load_checkpoint(model, args.checkpoint, map_location='cpu')

if args.fuse_conv_bn:
model = fuse_conv_bn(model)

if not distributed:
model = MMDataParallel(model, device_ids=[0])
outputs = single_gpu_test(model, data_loader)
if args.tensorrt:
outputs = inference_tensorrt(args.checkpoint, distributed, data_loader,
dataloader_setting['videos_per_gpu'])
elif args.onnx:
outputs = inference_onnx(args.checkpoint, distributed, data_loader,
dataloader_setting['videos_per_gpu'])
else:
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False)
outputs = multi_gpu_test(model, data_loader, args.tmpdir,
args.gpu_collect)
outputs = inference_pytorch(args, cfg, distributed, data_loader)

rank, _ = get_dist_info()
if rank == 0:
