Skip to content

Commit

Permalink
Add unconnected input and output variables to the input/output OP of …
Browse files Browse the repository at this point in the history
…a graph
  • Loading branch information
PINTO0309 committed Apr 27, 2022
1 parent ffae9e4 commit c491c22
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
2 changes: 1 addition & 1 deletion sna4onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from sna4onnx.onnx_operation_adder import add, main

__version__ = '1.0.0'
__version__ = '1.0.1'
46 changes: 46 additions & 0 deletions sna4onnx/onnx_operation_adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from argparse import ArgumentParser
import onnx
import onnx_graphsurgeon as gs
from onnx_graphsurgeon.ir.tensor import Variable
import numpy as np
from typing import Optional, List
import sog4onnx
Expand Down Expand Up @@ -286,6 +287,51 @@ def add(

graph.cleanup().toposort()

# Add unconnected input and output variables to the input/output OP of a graph
graph_input_variables = []
graph_output_variables = []

# Extraction of input variables
for graph_node in graph.nodes:
try:
for input in graph_node.inputs:
if isinstance(input, Variable) and input not in graph.inputs:
graph_input_variables.append(input)
except:
pass

# Extraction of output variables
for graph_node in graph.nodes:
try:
for output in graph_node.outputs:
if isinstance(output, Variable) and output not in graph.outputs:
graph_output_variables.append(output)
except:
pass

graph_node_input_names = [
graph_node_input.name for graph_node in graph.nodes for graph_node_input in graph_node.inputs
]
graph_node_output_names = [
graph_node_output.name for graph_node in graph.nodes for graph_node_output in graph_node.outputs
]

# Extract unused input variables and assign them to graph inputs
for graph_input_variable in graph_input_variables:
if graph_input_variable.name in graph_node_output_names:
pass
else:
graph.inputs.append(graph_input_variable)

# Extract unused output variables and assign them to graph output
for graph_output_variable in graph_output_variables:
if graph_output_variable.name in graph_node_input_names:
pass
else:
graph.outputs.append(graph_output_variable)

graph.cleanup().toposort()

# Shape Estimation
changed_graph = None
try:
Expand Down

0 comments on commit c491c22

Please sign in to comment.