Skip to content

Commit

Permalink
Eliminated comma delimiters in input parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed May 9, 2022
1 parent 5514755 commit 0b4ec07
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 27 deletions.
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,14 @@ optional arguments:
--forced_extraction_op_names FORCED_EXTRACTION_OP_NAMES
Extracts the constant value of the specified OP name to .npy
regardless of the mode specified.
Specify the name of the OP, separated by commas.
Cannot be used with --forced_extraction_constant_names at the same time.
e.g. --forced_extraction_op_names aaa,bbb,ccc
e.g. --forced_extraction_op_names aaa bbb ccc

--forced_extraction_constant_names FORCED_EXTRACTION_CONSTANT_NAMES
Extracts the constant value of the specified Constant name to .npy
regardless of the mode specified.
Specify the name of the Constant, separated by commas.
Cannot be used with --forced_extraction_op_names at the same time.
e.g. --forced_extraction_constant_names aaa,bbb,ccc
e.g. --forced_extraction_constant_names aaa bbb ccc

--disable_auto_downcast
Disables automatic downcast processing from Float64 to Float32 and INT64
Expand Down
2 changes: 1 addition & 1 deletion scs4onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from scs4onnx.onnx_shrink_constant import shrinking, main

__version__ = '1.0.15'
__version__ = '1.0.16'
41 changes: 19 additions & 22 deletions scs4onnx/onnx_shrink_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,19 +231,20 @@ def shrinking(
# Constant Value Extraction
# 1. OP Name
constants = {}
graph_nodes = [node for node in graph.nodes if node.name in forced_extraction_op_names]
for graph_node in graph_nodes:
for graph_node_input in graph_node.inputs:
if not isinstance(graph_node_input, Constant):
continue
if len(graph_node_input.shape) == 0:
continue
if np.isscalar(graph_node_input.values):
continue
constants[graph_node_input.name] = graph_node_input
if forced_extraction_op_names:
graph_nodes = [node for node in graph.nodes if node.name in forced_extraction_op_names]
for graph_node in graph_nodes:
for graph_node_input in graph_node.inputs:
if not isinstance(graph_node_input, Constant):
continue
if len(graph_node_input.shape) == 0:
continue
if np.isscalar(graph_node_input.values):
continue
constants[graph_node_input.name] = graph_node_input

# 2. Constant Name
if len(forced_extraction_constant_names) > 0:
if forced_extraction_constant_names:
for graph_node in graph.nodes:
for graph_node_input in graph_node.inputs:
if graph_node_input.name in forced_extraction_constant_names:
Expand Down Expand Up @@ -355,24 +356,22 @@ def main():
parser.add_argument(
'--forced_extraction_op_names',
type=str,
default='',
nargs='+',
help="\
Extracts the constant value of the specified OP name to .npy \
regardless of the mode specified. \
Specify the name of the OP, separated by commas. \
Cannot be used with --forced_extraction_constant_names at the same time. \
e.g. --forced_extraction_op_names aaa,bbb,ccc"
e.g. --forced_extraction_op_names aaa bbb ccc"
)
parser.add_argument(
'--forced_extraction_constant_names',
type=str,
default='',
nargs='+',
help="\
Extracts the constant value of the specified Constant name to .npy \
regardless of the mode specified. \
Specify the name of the Constant, separated by commas. \
Cannot be used with --forced_extraction_op_names at the same time. \
e.g. --forced_extraction_constant_names aaa,bbb,ccc"
e.g. --forced_extraction_constant_names aaa bbb ccc"
)
parser.add_argument(
'--disable_auto_downcast',
Expand All @@ -399,12 +398,10 @@ def main():
)
sys.exit(1)

forced_extraction_op_names = args.forced_extraction_op_names.strip(' ,').replace(' ','').split(',')
forced_extraction_op_names = [op_name for op_name in forced_extraction_op_names if op_name != '']
forced_extraction_constant_names = args.forced_extraction_constant_names.strip(' ,').replace(' ','').split(',')
forced_extraction_constant_names = [op_name for op_name in forced_extraction_constant_names if op_name != '']
forced_extraction_op_names = args.forced_extraction_op_names
forced_extraction_constant_names = args.forced_extraction_constant_names

if len(forced_extraction_op_names) > 0 and len(forced_extraction_constant_names) > 0:
if forced_extraction_op_names and forced_extraction_constant_names:
print(
f'{Color.RED}ERROR:{Color.RESET} '+
f'Only one of forced_extraction_op_names and forced_extraction_constant_names can be specified. '+
Expand Down

0 comments on commit 0b4ec07

Please sign in to comment.