This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Support the List/Tuple Construct/Unpack operation for TorchModuleGraph #2609

Merged
merged 5 commits on Jul 24, 2020
86 changes: 77 additions & 9 deletions src/sdk/pynni/nni/_graph_utils.py
@@ -11,6 +11,10 @@
CLASSTYPE_KIND = 'ClassType'
GETATTR_KIND = 'prim::GetAttr'
CAT_KIND = 'aten::cat'
LIST_CONSTRUCT_KIND = 'prim::ListConstruct'
LIST_UNPACK_KIND = 'prim::ListUnpack'
TUPLE_CONSTRUCT_KIND = 'prim::TupleConstruct'
TUPLE_UNPACK_KIND = 'prim::TupleUnpack'

_logger = logging.getLogger(__name__)

@@ -177,7 +181,7 @@ class NodePyGroup(NodePy):
represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
"""

def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None, outputs=None):
def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None, outputs=None, key_node=None):
"""
Parameters:
-----------
@@ -199,6 +203,8 @@ def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None
All the inputs of this node, each element is debugName of one input
outputs: list of str
All the outputs of this node, each element is debugName of one output
key_node: torch._C.Node
The key node of this NodePyGroup.
Contributor:
what is the meaning of key node?

Contributor Author:
Key nodes are the nodes that should not be merged into other nodes. In the past, we only took the aten:: nodes as the important (key) nodes. In this PR, we also take the list/tuple unpack nodes as key nodes.
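For illustration, a minimal sketch (not part of this PR) of how these kinds show up in a traced graph; the exact node kinds can vary across PyTorch versions:

import torch

class Demo(torch.nn.Module):
    def forward(self, x):
        a, b = x.chunk(2, dim=1)          # aten::chunk followed by prim::ListUnpack
        return torch.cat([a, b], dim=1)   # prim::ListConstruct feeding aten::cat

traced = torch.jit.trace(Demo(), torch.randn(1, 4))
for node in traced.graph.nodes():
    print(node.kind())
# The aten::* nodes and the prim::ListUnpack node are treated as key nodes here;
# the prim::ListConstruct node gets merged into the aten::cat group.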

"""
super(NodePyGroup, self).__init__(name, [])
self.node_cpps = node_cpps
@@ -211,6 +217,8 @@ def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None
self.add_nodes(node_cpps)
self.inputs = inputs
self.outputs = outputs
# The core node in this NodePyGroup
self.key_node = key_node

def add_nodes(self, node_cpps):
for node_cpp in node_cpps:
@@ -239,7 +247,7 @@ def __init__(self, model=None, dummy_input=None, traced_model=None):
self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
self._extract_auxiliary_info()

def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node,
def _expand_key_func_node(self, node, nodes, input_to_node, output_to_node,
module_type):
"""
For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
@@ -284,7 +292,7 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node,
input_name = _input.debugName()
if input_name in output_to_node and output_to_node[input_name] in nodes:
predecessor_node = output_to_node[input_name]
if predecessor_node.kind().startswith('prim::'):
if not self._is_key_func(predecessor_node):
node_group.append(predecessor_node)
node_queue.put(predecessor_node)
else:
@@ -294,7 +302,7 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node,
for output in node.outputs():
outputs.append(output.debugName())
nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=inputs, outputs=outputs)
node_group, inputs=inputs, outputs=outputs, key_node=node)
return nodepy

def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
@@ -510,6 +518,65 @@ def _build_index(self, nodes_op):
output_to_node[output] = node
return name_to_node, input_to_node, output_to_node

def _is_key_func(self, node_cpp):
"""
Judge if a cpp node is a key function node.
If so, we should not merge this node into the
adjacent node.
"""
if node_cpp.kind().startswith('aten::'):
# the nodes that start with 'aten' are key function
# nodes
return True
if node_cpp.kind() in [LIST_UNPACK_KIND, TUPLE_UNPACK_KIND]:
Contributor:
not construct type here?

# We cannot merge the List/Tuple
# Construct/Unpack func into other nodes, else it
# may lead to a graph construction error.
Contributor:
what is this error like?

Contributor Author:
Take the shufflenet as an example: #2581

return True
return False

def unpack_manually(self):
"""
Unpack the tensor tuple or tensor list manually,
and remove the ListUnpack/TupleUnpack node from
the graph. Note: this function will change the
graph structure.
"""
if hasattr(self, 'unpacked'):
# if already unpacked the tuple/list manually
return
for node in self.nodes_py.nodes_op:
if node.op_type in [TUPLE_UNPACK_KIND, LIST_UNPACK_KIND]:
unpack_cpp = node.key_node
last_cpp = list(unpack_cpp.inputs())[0].node()
Contributor:
index is 0, why call it last?

Contributor Author:
The last node actually refers to the previous (last) visited node. In most scenarios, this last_cpp is the corresponding construct node of the tuple/list.
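For illustration, a hedged sketch (not part of the PR) of what this lookup does on a traced graph: inputs()[0].node() returns the node that produced the packed value, and when that producer is a construct node, its inputs correspond one-to-one with the unpack node's outputs, which is what the assert below checks.

# assumes `traced` is a torch.jit traced module, as in the earlier sketch
for cpp_node in traced.graph.nodes():
    if cpp_node.kind() in ('prim::TupleUnpack', 'prim::ListUnpack'):
        producer = list(cpp_node.inputs())[0].node()  # the previously visited ("last") node
        if producer.kind() in ('prim::TupleConstruct', 'prim::ListConstruct'):
            # the construct inputs and the unpack outputs describe the same tensors,
            # so their counts must match
            assert len(list(producer.inputs())) == len(list(cpp_node.outputs()))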

if last_cpp.kind() in [TUPLE_CONSTRUCT_KIND, LIST_CONSTRUCT_KIND]:
# we need to check if the tensor tuple or tensor list is produced
# by a list/tuple construct node. If so, we can unpack the tuple
# or list manually.
_logger.debug('List/Tuple Construct Node(cpp) %s', str(last_cpp))
_logger.debug('List/Tuple Unpack Node(cpp) %s', str(unpack_cpp))
assert len(list(unpack_cpp.outputs())) == len(list(last_cpp.inputs()))
Contributor:
A little confused about this assert; I think the main reason is I don't understand what last_cpp is.

for _input, _output in zip(last_cpp.inputs(), unpack_cpp.outputs()):
_debug_input = _input.debugName()
_debug_output = _output.debugName()
if _debug_input in self.input_to_node and _debug_output in self.input_to_node:
# input_to_node[_debug_input] is a list of NodePyGroup, because
# one tensor can be used as input for multiple nodes at the same time.

# note that, in this case, the construct cpp node and unpack cpp node
# will be merged into the same NodePyGroup, so we remove the `node` from
# input_to_node[_debug_input] and directly connect this tensor to the
# input_to_node[_debug_output]
self.input_to_node[_debug_input].remove(node)
# add the following nodes of _output into the input_to_node[_debug_input]
self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output])
if _debug_input in self.output_to_node and _debug_output in self.output_to_node:
# output_to_node[_debug_output] is a NodePyGroup, because one output
# tensor only can be generated by one node.
self.output_to_node[_debug_output] = self.output_to_node[_debug_input]

self.unpacked = True

def _build_graph(self):
"""
Build graph using our defined format from jit trace.
@@ -585,13 +652,14 @@ def _build_graph(self):
# build node group for torch.nn.functional
for _, nodes in func_to_nodes.items():
# extract non prim:: nodes
non_prim_nodes = list()
key_func_nodes = list()
for node in nodes:
if not node.kind().startswith('prim::'):
non_prim_nodes.append(node)
if self._is_key_func(node):
# find the key function nodes
key_func_nodes.append(node)
# for each non prim node, expand it
for node in non_prim_nodes:
node_group = self._expand_non_prim_node(
for node in key_func_nodes:
node_group = self._expand_key_func_node(
node, nodes, input_to_node, output_to_node, 'func')
nodes_py.nodes_op.append(node_group)
# get shape info for view (aten::view) func
3 changes: 3 additions & 0 deletions src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py
@@ -86,6 +86,9 @@ def build_dependency(self):
Build the channel dependency for the conv layers
in the model.
"""
# unpack the tuple/list manually before analyzing the
# channel dependency
self.graph.unpack_manually()
for node in self.graph.nodes_py.nodes_op:
parent_layers = []
# find the node that contains aten::add
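For reference, a minimal usage sketch (not part of this diff), assuming the nni._graph_utils module path from this repository layout and shufflenet as the example model mentioned above:

import torch
import torchvision.models as models
from nni._graph_utils import TorchModuleGraph

model = models.shufflenet_v2_x1_0()
dummy_input = torch.randn(1, 3, 224, 224)

graph = TorchModuleGraph(model, dummy_input)
graph.unpack_manually()   # resolve the List/Tuple Construct/Unpack pairs in the graph
# build_dependency() in shape_dependency.py now calls unpack_manually() itself,
# so the channel-dependency analysis sees the graph with the unpack nodes resolved.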