Skip to content

Commit

Permalink
Supports OP generation for TensorRT Plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Jun 24, 2023
1 parent 48dd41e commit 06961a6
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 19 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ sog4onnx/saved_model/
__pycache__/

*.onnx
*.npy
*.npy
*.engine
*.profile
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,28 @@ $ sog4onnx \
```
![image](https://user-images.githubusercontent.com/33194443/163311192-b559134f-d42d-4119-8990-0f7ac63230e3.png)

### 6-5. opset=11, EfficientNMS_TRT (TensorRT Efficient NMS Plugin)
```bash
$ sog4onnx \
--op_type EfficientNMS_TRT \
--opset 11 \
--op_name trt_nms_efficient_std_11 \
--input_variables boxes float32 [1,3549,4] \
--input_variables scores float32 [1,3549,16] \
--attributes plugin_version str 1 \
--attributes score_threshold float32 0.25 \
--attributes iou_threshold float32 0.45 \
--attributes max_output_boxes int64 20 \
--attributes background_class int64 -1 \
--attributes score_activation bool False \
--attributes box_coding int64 0 \
--output_variables num_detections int32 [1,1] \
--output_variables detection_boxes float32 [1,20,4] \
--output_variables detection_scores float32 [1,20] \
--output_variables detection_classes int32 [1,20]
```
![image](https://github.com/PINTO0309/sog4onnx/assets/33194443/1b3989fd-cd73-4b1e-af59-cda25ea61a97)

## 7. Reference
1. https://github.com/onnx/onnx/blob/main/docs/Operators.md
2. https://docs.nvidia.com/deeplearning/tensorrt/onnx-graphsurgeon/docs/index.html
Expand Down
2 changes: 1 addition & 1 deletion sog4onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from sog4onnx.onnx_operation_generator import generate, main

__version__ = '1.0.15'
__version__ = '1.0.16'
52 changes: 35 additions & 17 deletions sog4onnx/onnx_operation_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import onnx_graphsurgeon as gs
import numpy as np
from typing import Optional
from collections import OrderedDict

class Color:
BLACK = '\033[30m'
Expand Down Expand Up @@ -40,6 +41,7 @@ class Color:
'int32',
'int64',
'str',
'bool',
]

DTYPES_TO_ONNX_DTYPES = {
Expand All @@ -53,22 +55,24 @@ class Color:
'float64': np.float64,
'int32': np.int32,
'int64': np.int64,
'bool': np.bool_
}

NUMPY_TYPES_TO_ONNX_DTYPES = {
np.dtype('float32'): onnx.TensorProto.FLOAT,
np.dtype('float64'): onnx.TensorProto.DOUBLE,
np.dtype('int32'): onnx.TensorProto.INT32,
np.dtype('int64'): onnx.TensorProto.INT64,
np.dtype('bool_'): onnx.TensorProto.BOOL,
}

def generate(
op_type: str,
opset: int,
op_name: str,
input_variables: Optional[dict] = None,
output_variables: Optional[dict] = None,
attributes: Optional[dict] = None,
input_variables: Optional[OrderedDict] = None,
output_variables: Optional[OrderedDict] = None,
attributes: Optional[OrderedDict] = None,
output_onnx_file_path: Optional[str] = '',
non_verbose: Optional[bool] = False,
) -> onnx.ModelProto:
Expand All @@ -88,21 +92,21 @@ def generate(
op_name: str
OP name.
input_variables: Optional[dict]
input_variables: Optional[OrderedDict]
Specify input variables for the OP to be generated.\n\
See below for the variables that can be specified.\n\n\
{"input_var_name1": [numpy.dtype, shape], "input_var_name2": [dtype, shape], ...}\n\n\
e.g. input_variables = {"name1": [np.float32, [1,224,224,3]], "name2": [np.bool_, [0]], ...}\n\
https://github.com/onnx/onnx/blob/main/docs/Operators.md
output_variables: Optional[dict]
output_variables: Optional[OrderedDict]
Specify output variables for the OP to be generated.\n\
See below for the variables that can be specified.\n\n\
{"output_var_name1": [numpy.dtype, shape], "output_var_name2": [dtype, shape], ...}\n\n\
e.g. output_variables = {"name1": [np.float32, [1,224,224,3]], "name2": [np.bool_, [0]], ...}\n\
https://github.com/onnx/onnx/blob/main/docs/Operators.md
attributes: Optional[dict]
attributes: Optional[OrderedDict]
Specify output attributes for the OP to be generated.\n\
See below for the attributes that can be specified.\n\n\
{"attr_name1": value1, "attr_name2": value2, "attr_name3": value3, ...}\n\n\
Expand Down Expand Up @@ -136,7 +140,9 @@ def generate(
"""
input_gs_variables = None
if input_variables:
input_gs_variables = [gs.Variable(name=key, dtype=value[0], shape=value[1]) for key, value in input_variables.items()]
input_gs_variables = [
gs.Variable(name=key, dtype=value[0], shape=value[1]) for key, value in input_variables.items()
]

"""
output_gs_variables
Expand All @@ -148,7 +154,9 @@ def generate(
"""
output_gs_variables = None
if output_variables:
output_gs_variables = [gs.Variable(name=key, dtype=value[0], shape=value[1]) for key, value in output_variables.items()]
output_gs_variables = [
gs.Variable(name=key, dtype=value[0], shape=value[1]) for key, value in output_variables.items()
]

# 2. Node Generation
node = None
Expand Down Expand Up @@ -219,12 +227,16 @@ def generate(

# 4. Graph Check
try:
onnx.checker.check_model(
model=single_op_graph,
full_check=False
)
if not non_verbose:
print(f'{Color.GREEN}INFO:{Color.RESET} The model is checked!')
if not op_type.endswith('_TRT'):
onnx.checker.check_model(
model=single_op_graph,
full_check=False
)
if not non_verbose:
print(f'{Color.GREEN}INFO:{Color.RESET} The model is checked!')
else:
if not non_verbose:
print(f'{Color.GREEN}INFO:{Color.RESET} Model checker was skipped due to OP regarding TRT plugin.')

except Exception as e:
tracetxt = traceback.format_exc().splitlines()[-1]
Expand Down Expand Up @@ -344,15 +356,21 @@ def main():
"""
input_variables_tmp = None
if args.input_variables:
input_variables_tmp = {input_variable[0]: [getattr(np, input_variable[1]), ast.literal_eval(input_variable[2])] for input_variable in args.input_variables}
input_variables_tmp = \
OrderedDict(
{input_variable[0]: [getattr(np, input_variable[1]), ast.literal_eval(input_variable[2])] for input_variable in args.input_variables}
)

# output variables
"""
output_variables_tmp = {'name': [dtype, shape]}
"""
output_variables_tmp = None
if args.output_variables:
output_variables_tmp = {output_variable[0]: [getattr(np, output_variable[1]), ast.literal_eval(output_variable[2])] for output_variable in args.output_variables}
output_variables_tmp = \
OrderedDict(
{output_variable[0]: [getattr(np, output_variable[1]), ast.literal_eval(output_variable[2])] for output_variable in args.output_variables}
)

# attributes
"""
Expand All @@ -369,7 +387,7 @@ def main():
)
sys.exit(1)

attributes_tmp = {}
attributes_tmp = OrderedDict({})
for attribute in args.attributes:
# parse
attr_name = attribute[0]
Expand Down

0 comments on commit 06961a6

Please sign in to comment.