Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
support the scenario that there are duplicate tensors in a same tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
Ningxin authored and Ningxin committed Jan 27, 2021
1 parent 64efd60 commit 81ed1f1
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions nni/common/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ def _expand_key_func_node(self, node, nodes, input_to_node, output_to_node,
self.global_count += 1
op_type = node.kind()
node_group = [node]
inputs = set()
outputs = set()
inputs = []
outputs = []
node_queue = queue.Queue()
node_queue.put(node)
while not node_queue.empty():
Expand All @@ -303,17 +303,17 @@ def _expand_key_func_node(self, node, nodes, input_to_node, output_to_node,
node_group.append(predecessor_node)
node_queue.put(predecessor_node)
else:
inputs.add(input_name)
inputs.append(input_name)
else:
inputs.add(input_name)
inputs.append(input_name)
else:
inputs.add(input_name)
inputs.append(input_name)
for output in node.outputs():
if output.node().kind() == CONSTANT_KIND:
continue
outputs.add(output.debugName())
outputs.append(output.debugName())
nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=list(inputs), outputs=list(outputs), key_node=node)
node_group, inputs=inputs, outputs=outputs, key_node=node)
return nodepy

def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
Expand Down Expand Up @@ -353,8 +353,8 @@ def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
if not op_type:
op_type = node.kind()
node_group = [node]
inputs = set()
outputs = set()
inputs = []
outputs = []
node_queue = queue.Queue()
node_queue.put(node)
visited = {node}
Expand All @@ -372,9 +372,9 @@ def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
node_queue.put(predecessor_node)
visited.add(predecessor_node)
else:
inputs.add(input_name)
inputs.append(input_name)
else:
inputs.add(input_name)
inputs.append(input_name)
for _output in curr_node.outputs():
if _output.node().kind() == CONSTANT_KIND:
continue
Expand All @@ -387,9 +387,9 @@ def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
node_queue.put(successor_node)
visited.add(successor_node)
else:
outputs.add(output_name)
outputs.append(output_name)
else:
outputs.add(output_name)
outputs.append(output_name)

nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=list(inputs), outputs=list(outputs))
Expand Down Expand Up @@ -562,10 +562,13 @@ def _build_index(self, nodes_op):
for node in nodes_op:
name_to_node[node.unique_name] = node
for _input in node.inputs:
input_to_node[_input].append(node)
# inputs may have duplicate tensors
if node not in input_to_node[_input]:
input_to_node[_input].append(node)
for output in node.outputs:
assert not output in output_to_node, \
"One output cannot be generated by multiple nodes %s" % output
if output in output_to_node:
assert output_to_node[output] == node, \
"One output cannot be generated by multiple nodes %s" % output
output_to_node[output] = node
return name_to_node, input_to_node, output_to_node

Expand Down Expand Up @@ -619,8 +622,6 @@ def unpack_manually(self):

assert len(node.inputs) == len(list(last_cpp.inputs())), errmsg
for _debug_input, _debug_output in zip(node.inputs, node.outputs):
# _debug_input = _input.debugName()
# _debug_output = _output.debugName()
if _debug_input in self.input_to_node and _debug_output in self.input_to_node:
# input_to_node[_debug_input] is a list of NodePyGroup, because
# one tensor can be used as input for multiple nodes at the same time.
Expand All @@ -629,7 +630,8 @@ def unpack_manually(self):
# will be merged into the same NodePyGroup, so we remove the `node` from
# input_to_node[_debug_input] and directly connect this tensor to the
# input_to_node[_debug_output]
self.input_to_node[_debug_input].remove(node)
if node in self.input_to_node[_debug_input]:
self.input_to_node[_debug_input].remove(node)
# add the following nodes of _output into the input_to_node[_debug_input]
self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output])
# just remove the _debug_output from the grapgh index. So that we can also skip
Expand Down

0 comments on commit 81ed1f1

Please sign in to comment.