Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Fix a bug in tuple unpack.
Browse files Browse the repository at this point in the history
The input may be reused in the Tuple_constuct node.
for example:
['input1', 'input1', 'input2']
  • Loading branch information
Ningxin authored and Ningxin committed Jan 27, 2021
1 parent 64efd60 commit 9272bea
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions nni/common/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(self, model=None, dummy_input=None, traced_model=None):
self.trace = traced_model
# it's ok if the graph is already unpacked
torch._C._jit_pass_inline(self.trace.graph)

elif model is not None and dummy_input is not None:
self.bound_model = model
self._trace(model, dummy_input)
Expand Down Expand Up @@ -617,8 +618,9 @@ def unpack_manually(self):
errmsg = '%s Input number: %d if inconsistent with the output number %d' % (unpack_cpp, \
len(node.inputs), len(list(last_cpp.inputs())))

assert len(node.inputs) == len(list(last_cpp.inputs())), errmsg
for _debug_input, _debug_output in zip(node.inputs, node.outputs):
in_debugnames = [x.debugName() for x in list(last_cpp.inputs())]
assert len(in_debugnames) == len(node.outputs), errmsg
for _debug_input, _debug_output in zip(in_debugnames, node.outputs):
# _debug_input = _input.debugName()
# _debug_output = _output.debugName()
if _debug_input in self.input_to_node and _debug_output in self.input_to_node:
Expand All @@ -629,7 +631,8 @@ def unpack_manually(self):
# will be merged into the same NodePyGroup, so we remove the `node` from
# input_to_node[_debug_input] and directly connect this tensor to the
# input_to_node[_debug_output]
self.input_to_node[_debug_input].remove(node)
if node in self.input_to_node:
self.input_to_node[_debug_input].remove(node)
# add the following nodes of _output into the input_to_node[_debug_input]
self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output])
# just remove the _debug_output from the grapgh index. So that we can also skip
Expand Down

0 comments on commit 9272bea

Please sign in to comment.