From 0b4ec07361fa372ba89bee4539a90b985104f1d9 Mon Sep 17 00:00:00 2001 From: pinto0309 Date: Mon, 9 May 2022 22:48:29 +0900 Subject: [PATCH] Eliminated comma delimiters in input parameters --- README.md | 6 ++--- scs4onnx/__init__.py | 2 +- scs4onnx/onnx_shrink_constant.py | 41 +++++++++++++++----------------- 3 files changed, 22 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index b973e3b..92463e2 100644 --- a/README.md +++ b/README.md @@ -71,16 +71,14 @@ optional arguments: --forced_extraction_op_names FORCED_EXTRACTION_OP_NAMES Extracts the constant value of the specified OP name to .npy regardless of the mode specified. - Specify the name of the OP, separated by commas. Cannot be used with --forced_extraction_constant_names at the same time. - e.g. --forced_extraction_op_names aaa,bbb,ccc + e.g. --forced_extraction_op_names aaa bbb ccc --forced_extraction_constant_names FORCED_EXTRACTION_CONSTANT_NAMES Extracts the constant value of the specified Constant name to .npy regardless of the mode specified. - Specify the name of the Constant, separated by commas. Cannot be used with --forced_extraction_op_names at the same time. - e.g. --forced_extraction_constant_names aaa,bbb,ccc + e.g. --forced_extraction_constant_names aaa bbb ccc --disable_auto_downcast Disables automatic downcast processing from Float64 to Float32 and INT64 diff --git a/scs4onnx/__init__.py b/scs4onnx/__init__.py index 50e1aa1..edb0b87 100644 --- a/scs4onnx/__init__.py +++ b/scs4onnx/__init__.py @@ -1,3 +1,3 @@ from scs4onnx.onnx_shrink_constant import shrinking, main -__version__ = '1.0.15' +__version__ = '1.0.16' diff --git a/scs4onnx/onnx_shrink_constant.py b/scs4onnx/onnx_shrink_constant.py index b07d072..6469be4 100644 --- a/scs4onnx/onnx_shrink_constant.py +++ b/scs4onnx/onnx_shrink_constant.py @@ -231,19 +231,20 @@ def shrinking( # Constant Value Extraction # 1. OP Name constants = {} - graph_nodes = [node for node in graph.nodes if node.name in forced_extraction_op_names] - for graph_node in graph_nodes: - for graph_node_input in graph_node.inputs: - if not isinstance(graph_node_input, Constant): - continue - if len(graph_node_input.shape) == 0: - continue - if np.isscalar(graph_node_input.values): - continue - constants[graph_node_input.name] = graph_node_input + if forced_extraction_op_names: + graph_nodes = [node for node in graph.nodes if node.name in forced_extraction_op_names] + for graph_node in graph_nodes: + for graph_node_input in graph_node.inputs: + if not isinstance(graph_node_input, Constant): + continue + if len(graph_node_input.shape) == 0: + continue + if np.isscalar(graph_node_input.values): + continue + constants[graph_node_input.name] = graph_node_input # 2. Constant Name - if len(forced_extraction_constant_names) > 0: + if forced_extraction_constant_names: for graph_node in graph.nodes: for graph_node_input in graph_node.inputs: if graph_node_input.name in forced_extraction_constant_names: @@ -355,24 +356,22 @@ def main(): parser.add_argument( '--forced_extraction_op_names', type=str, - default='', + nargs='+', help="\ Extracts the constant value of the specified OP name to .npy \ regardless of the mode specified. \ - Specify the name of the OP, separated by commas. \ Cannot be used with --forced_extraction_constant_names at the same time. \ - e.g. --forced_extraction_op_names aaa,bbb,ccc" + e.g. --forced_extraction_op_names aaa bbb ccc" ) parser.add_argument( '--forced_extraction_constant_names', type=str, - default='', + nargs='+', help="\ Extracts the constant value of the specified Constant name to .npy \ regardless of the mode specified. \ - Specify the name of the Constant, separated by commas. \ Cannot be used with --forced_extraction_op_names at the same time. \ - e.g. --forced_extraction_constant_names aaa,bbb,ccc" + e.g. --forced_extraction_constant_names aaa bbb ccc" ) parser.add_argument( '--disable_auto_downcast', @@ -399,12 +398,10 @@ def main(): ) sys.exit(1) - forced_extraction_op_names = args.forced_extraction_op_names.strip(' ,').replace(' ','').split(',') - forced_extraction_op_names = [op_name for op_name in forced_extraction_op_names if op_name != ''] - forced_extraction_constant_names = args.forced_extraction_constant_names.strip(' ,').replace(' ','').split(',') - forced_extraction_constant_names = [op_name for op_name in forced_extraction_constant_names if op_name != ''] + forced_extraction_op_names = args.forced_extraction_op_names + forced_extraction_constant_names = args.forced_extraction_constant_names - if len(forced_extraction_op_names) > 0 and len(forced_extraction_constant_names) > 0: + if forced_extraction_op_names and forced_extraction_constant_names: print( f'{Color.RED}ERROR:{Color.RESET} '+ f'Only one of forced_extraction_op_names and forced_extraction_constant_names can be specified. '+