Skip to content

Commit

Permalink
Support for SequenceConstruct
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Sep 1, 2022
1 parent 9aac849 commit 98e5340
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
2 changes: 1 addition & 1 deletion sit4onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from sit4onnx.onnx_inference_test import inference, main

__version__ = '1.0.4'
__version__ = '1.0.5'
29 changes: 23 additions & 6 deletions sit4onnx/onnx_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,12 +371,21 @@ def inference(
f'{Color.BLUE}avg elapsed time per pred: {Color.RESET} {e / (test_loop_count - 1) * 1000} ms'
)
for idx, ort_output_name, result in zip(range(1,len(ort_output_names)+1), ort_output_names, results):
print(\
f'{Color.GREEN}INFO:{Color.RESET} '+ \
f'{Color.BLUE}output_name.{idx}:{Color.RESET} {ort_output_name} '+ \
f'{Color.BLUE}shape:{Color.RESET} {[dim for dim in result.shape]} '+ \
f'{Color.BLUE}dtype:{Color.RESET} {result.dtype}'
)
if not isinstance(result, List):
print(\
f'{Color.GREEN}INFO:{Color.RESET} '+ \
f'{Color.BLUE}output_name.{idx}:{Color.RESET} {ort_output_name} '+ \
f'{Color.BLUE}shape:{Color.RESET} {[dim for dim in result.shape]} '+ \
f'{Color.BLUE}dtype:{Color.RESET} {result.dtype}'
)
else:
for sub_idx, sub_result in enumerate(result):
print(\
f'{Color.GREEN}INFO:{Color.RESET} '+ \
f'{Color.BLUE}output_name.{idx}-{sub_idx}:{Color.RESET} {ort_output_name}-{sub_idx} '+ \
f'{Color.BLUE}shape:{Color.RESET} {[dim for dim in sub_result.shape]} '+ \
f'{Color.BLUE}dtype:{Color.RESET} {sub_result.dtype}'
)

# Return
return results
Expand All @@ -385,12 +394,14 @@ def inference(
def main():
parser = ArgumentParser()
parser.add_argument(
'-if',
'--input_onnx_file_path',
type=str,
required=True,
help='Input onnx file path.'
)
parser.add_argument(
'-b',
'--batch_size',
type=int,
default=1,
Expand All @@ -401,6 +412,7 @@ def main():
'numpy_ndarrays_for_testing or fixed_shapes is specified.'
)
parser.add_argument(
'-fs',
'--fixed_shapes',
type=int,
nargs='+',
Expand All @@ -415,6 +427,7 @@ def main():
'--fixed_shapes 1 1 224 224'
)
parser.add_argument(
'-tlc',
'--test_loop_count',
type=int,
default=10,
Expand All @@ -424,13 +437,15 @@ def main():
'and the average inference time per inference is displayed.'
)
parser.add_argument(
'-oep',
'--onnx_execution_provider',
type=str,
choices=ONNX_EXECUTION_PROVIDERS,
default='tensorrt',
help='ONNX Execution Provider.'
)
parser.add_argument(
'-ifp',
'--input_numpy_file_paths_for_testing',
type=str,
action='append',
Expand All @@ -444,11 +459,13 @@ def main():
'--input_numpy_file_paths_for_testing ccc.npy'
)
parser.add_argument(
'-ofp',
'--output_numpy_file',
action='store_true',
help='Outputs the last inference result to an .npy file.'
)
parser.add_argument(
'-n',
'--non_verbose',
action='store_true',
help='Do not show all information logs. Only error logs are displayed.'
Expand Down

0 comments on commit 98e5340

Please sign in to comment.