Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Support duplicate tensor in the tuple unpack #3340

Merged
merged 2 commits into from
Feb 3, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 21 additions & 19 deletions nni/common/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ def _expand_key_func_node(self, node, nodes, input_to_node, output_to_node,
self.global_count += 1
op_type = node.kind()
node_group = [node]
inputs = set()
outputs = set()
inputs = []
outputs = []
node_queue = queue.Queue()
node_queue.put(node)
while not node_queue.empty():
Expand All @@ -303,17 +303,17 @@ def _expand_key_func_node(self, node, nodes, input_to_node, output_to_node,
node_group.append(predecessor_node)
node_queue.put(predecessor_node)
else:
inputs.add(input_name)
inputs.append(input_name)
else:
inputs.add(input_name)
inputs.append(input_name)
else:
inputs.add(input_name)
inputs.append(input_name)
for output in node.outputs():
if output.node().kind() == CONSTANT_KIND:
continue
outputs.add(output.debugName())
outputs.append(output.debugName())
nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=list(inputs), outputs=list(outputs), key_node=node)
node_group, inputs=inputs, outputs=outputs, key_node=node)
return nodepy

def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
Expand Down Expand Up @@ -353,8 +353,8 @@ def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
if not op_type:
op_type = node.kind()
node_group = [node]
inputs = set()
outputs = set()
inputs = []
outputs = []
node_queue = queue.Queue()
node_queue.put(node)
visited = {node}
Expand All @@ -372,9 +372,9 @@ def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
node_queue.put(predecessor_node)
visited.add(predecessor_node)
else:
inputs.add(input_name)
inputs.append(input_name)
else:
inputs.add(input_name)
inputs.append(input_name)
for _output in curr_node.outputs():
if _output.node().kind() == CONSTANT_KIND:
continue
Expand All @@ -387,9 +387,9 @@ def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
node_queue.put(successor_node)
visited.add(successor_node)
else:
outputs.add(output_name)
outputs.append(output_name)
else:
outputs.add(output_name)
outputs.append(output_name)

nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=list(inputs), outputs=list(outputs))
Expand Down Expand Up @@ -562,10 +562,13 @@ def _build_index(self, nodes_op):
for node in nodes_op:
name_to_node[node.unique_name] = node
for _input in node.inputs:
input_to_node[_input].append(node)
# inputs may have duplicate tensors
if node not in input_to_node[_input]:
input_to_node[_input].append(node)
for output in node.outputs:
assert not output in output_to_node, \
"One output cannot be generated by multiple nodes %s" % output
if output in output_to_node:
assert output_to_node[output] == node, \
"One output cannot be generated by multiple nodes %s" % output
output_to_node[output] = node
return name_to_node, input_to_node, output_to_node

Expand Down Expand Up @@ -619,8 +622,6 @@ def unpack_manually(self):

assert len(node.inputs) == len(list(last_cpp.inputs())), errmsg
for _debug_input, _debug_output in zip(node.inputs, node.outputs):
# _debug_input = _input.debugName()
# _debug_output = _output.debugName()
if _debug_input in self.input_to_node and _debug_output in self.input_to_node:
# input_to_node[_debug_input] is a list of NodePyGroup, because
# one tensor can be used as input for multiple nodes at the same time.
Expand All @@ -629,7 +630,8 @@ def unpack_manually(self):
# will be merged into the same NodePyGroup, so we remove the `node` from
# input_to_node[_debug_input] and directly connect this tensor to the
# input_to_node[_debug_output]
self.input_to_node[_debug_input].remove(node)
if node in self.input_to_node[_debug_input]:
self.input_to_node[_debug_input].remove(node)
# add the following nodes of _output into the input_to_node[_debug_input]
self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output])
# just remove the _debug_output from the graph index. So that we can also skip
Expand Down