# Support the List/Tuple Construct/Unpack operation for TorchModuleGraph #2609
```diff
@@ -11,6 +11,10 @@
 CLASSTYPE_KIND = 'ClassType'
 GETATTR_KIND = 'prim::GetAttr'
 CAT_KIND = 'aten::cat'
+LIST_CONSTRUCT_KIND = 'prim::ListConstruct'
+LIST_UNPACK_KIND = 'prim::ListUnpack'
+TUPLE_CONSTRUCT_KIND = 'prim::TupleConstruct'
+TUPLE_UNPACK_KIND = 'prim::TupleUnpack'

 _logger = logging.getLogger(__name__)
```
```diff
@@ -177,7 +181,7 @@ class NodePyGroup(NodePy):
     represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
     """

-    def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None, outputs=None):
+    def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None, outputs=None, key_node=None):
        """
        Parameters:
        -----------
```
```diff
@@ -199,6 +203,8 @@ def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None
            All the inputs of this node, each element is debugName of one input
        outputs: list of str
            All the outputs of this node, each element is debugName of one output
+       key_node: torch._C.Node
+           The key node of this NodePyGroup.
        """
        super(NodePyGroup, self).__init__(name, [])
        self.node_cpps = node_cpps
```
```diff
@@ -211,6 +217,8 @@ def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None
        self.add_nodes(node_cpps)
        self.inputs = inputs
        self.outputs = outputs
+       # The core node in this NodePyGroup
+       self.key_node = key_node

    def add_nodes(self, node_cpps):
        for node_cpp in node_cpps:
```
```diff
@@ -239,7 +247,7 @@ def __init__(self, model=None, dummy_input=None, traced_model=None):
        self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
        self._extract_auxiliary_info()

-   def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node,
+   def _expand_key_func_node(self, node, nodes, input_to_node, output_to_node,
                              module_type):
        """
        For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
```
```diff
@@ -284,7 +292,7 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node,
            input_name = _input.debugName()
            if input_name in output_to_node and output_to_node[input_name] in nodes:
                predecessor_node = output_to_node[input_name]
-               if predecessor_node.kind().startswith('prim::'):
+               if not self._is_key_func(predecessor_node):
                    node_group.append(predecessor_node)
                    node_queue.put(predecessor_node)
                else:
```
```diff
@@ -294,7 +302,7 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node,
        for output in node.outputs():
            outputs.append(output.debugName())
        nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
-                            node_group, inputs=inputs, outputs=outputs)
+                            node_group, inputs=inputs, outputs=outputs, key_node=node)
        return nodepy

    def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
```
```diff
@@ -510,6 +518,65 @@ def _build_index(self, nodes_op):
             output_to_node[output] = node
         return name_to_node, input_to_node, output_to_node

+    def _is_key_func(self, node_cpp):
+        """
+        Judge if a cpp node is a key function node.
+        If so, we should not merge this node into the
+        adjacent node.
+        """
+        if node_cpp.kind().startswith('aten::'):
+            # the nodes that start with 'aten' are key
+            # function nodes
+            return True
+        if node_cpp.kind() in [LIST_UNPACK_KIND, TUPLE_UNPACK_KIND]:
```

> **Review comment:** not construct type here?

```diff
+            # We cannot merge the List/Tuple
+            # Construct/Unpack func into other nodes, else it
+            # may lead to a graph construction error.
```

> **Review comment:** what is this error like?
>
> **Reply:** Take the shufflenet as an example:

```diff
+            return True
+        return False
```
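For a concrete picture of where such Construct/Unpack nodes come from, here is a minimal sketch, not from the PR itself (`SplitCat` is an illustrative toy module; the exact node kinds can vary across PyTorch versions):

```python
import torch

class SplitCat(torch.nn.Module):
    def forward(self, x):
        # torch.split returns a tuple of tensors; in the traced graph this
        # typically shows up as aten::split followed by prim::ListUnpack,
        # and torch.cat consumes a prim::ListConstruct of its inputs.
        a, b = torch.split(x, 2, dim=1)
        return torch.cat([b, a], dim=1)

traced = torch.jit.trace(SplitCat(), torch.randn(1, 4))
for node in traced.graph.nodes():
    print(node.kind())  # expect prim::ListUnpack and prim::ListConstruct here
```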
```diff
+    def unpack_manually(self):
+        """
+        Unpack the tensor tuple or tensor list manually,
+        and remove the ListUnpack/TupleUnpack node from
+        the graph. Note: this function will change the
+        graph structure.
+        """
+        if hasattr(self, 'unpacked'):
+            # if we already unpacked the tuple/list manually
+            return
+        for node in self.nodes_py.nodes_op:
+            if node.op_type in [TUPLE_UNPACK_KIND, LIST_UNPACK_KIND]:
+                unpack_cpp = node.key_node
+                last_cpp = list(unpack_cpp.inputs())[0].node()
```

> **Review comment:** index is 0, why call it `last`?

```diff
+                if last_cpp.kind() in [TUPLE_CONSTRUCT_KIND, LIST_CONSTRUCT_KIND]:
+                    # we need to check if the tensor tuple or tensor list is produced
+                    # by a list/tuple construct node. If so, we can unpack the tuple
+                    # or list manually.
+                    _logger.debug('List/Tuple Construct Node(cpp) %s', str(last_cpp))
+                    _logger.debug('List/Tuple Unpack Node(cpp) %s', str(unpack_cpp))
+                    assert len(list(unpack_cpp.outputs())) == len(list(last_cpp.inputs()))
```

> **Review comment:** A little confused about this assert; I think the main reason is I don't understand what is
```diff
+                    for _input, _output in zip(last_cpp.inputs(), unpack_cpp.outputs()):
+                        _debug_input = _input.debugName()
+                        _debug_output = _output.debugName()
+                        if _debug_input in self.input_to_node and _debug_output in self.input_to_node:
+                            # input_to_node[_debug_input] is a list of NodePyGroup, because
+                            # one tensor can be used as input for multiple nodes at the same time.
+
+                            # note that, in this case, the construct cpp node and unpack cpp node
+                            # will be merged into the same NodePyGroup, so we remove the `node` from
+                            # input_to_node[_debug_input] and directly connect this tensor to the
+                            # input_to_node[_debug_output]
+                            self.input_to_node[_debug_input].remove(node)
+                            # add the following nodes of _output into the input_to_node[_debug_input]
+                            self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output])
+                        if _debug_input in self.output_to_node and _debug_output in self.output_to_node:
+                            # output_to_node[_debug_output] is a NodePyGroup, because one output
+                            # tensor can only be generated by one node.
+                            self.output_to_node[_debug_output] = self.output_to_node[_debug_input]
+
+        self.unpacked = True
+
     def _build_graph(self):
         """
         Build graph using our defined format from jit trace.
```
```diff
@@ -585,13 +652,14 @@ def _build_graph(self):
         # build node group for torch.nn.functional
         for _, nodes in func_to_nodes.items():
             # extract non prim:: nodes
-            non_prim_nodes = list()
+            key_func_nodes = list()
             for node in nodes:
-                if not node.kind().startswith('prim::'):
-                    non_prim_nodes.append(node)
+                if self._is_key_func(node):
+                    # find the key function nodes
+                    key_func_nodes.append(node)
             # for each non prim node, expand it
-            for node in non_prim_nodes:
-                node_group = self._expand_non_prim_node(
+            for node in key_func_nodes:
+                node_group = self._expand_key_func_node(
                     node, nodes, input_to_node, output_to_node, 'func')
                 nodes_py.nodes_op.append(node_group)
         # get shape infor for view (aten::view) func
```
> **Review comment:** what is the meaning of key node?
>
> **Reply:** Key nodes are the nodes that should not be merged into other nodes. In the past, we only took the `aten::` nodes as the important (key) nodes. In this PR, we also take the list/tuple unpack nodes as key nodes.
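Putting the pieces together, a usage sketch: the constructor keywords and `unpack_manually` come from the diff above, but the import path is an assumption, and `SplitCat` is the toy module from the earlier sketch.

```python
import torch
# import path is an assumption; TorchModuleGraph lives in NNI's graph utils
from nni._graph_utils import TorchModuleGraph

model = SplitCat()  # the toy module from the sketch above
graph = TorchModuleGraph(model=model, dummy_input=torch.randn(1, 4))
# every node whose op_type is prim::ListUnpack / prim::TupleUnpack carries
# its key_node, so the graph can be rewired around it:
graph.unpack_manually()
```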