fix(speedup): refactor the execution logic of functions in speedup #5107
@@ -309,8 +309,6 @@ def _expand_key_func_node(self, node, nodes, input_to_node, output_to_node,
         while not node_queue.empty():
             curr_node = node_queue.get()
             for _input in curr_node.inputs():
-                if _input.node().kind() == CONSTANT_KIND:
-                    continue
                 input_name = _input.debugName()
                 if input_name in output_to_node:
                     for predecessor_node in output_to_node[input_name]:
@@ -693,16 +691,10 @@ def _build_graph(self):
         input_to_node = defaultdict(list)
         output_to_node = defaultdict(list)
         for node in graph.nodes():
-            if node.kind() == CONSTANT_KIND:
-                continue
             for x in node.outputs():
-                if x.node().kind() == CONSTANT_KIND:
-                    continue
                 output_to_node[x.debugName()].append(node)
                 assert len(output_to_node[x.debugName()]) <= 1, "One output cannot be generated by multiple nodes %s" % x.debugName()
             for x in node.inputs():
-                if x.node().kind() == CONSTANT_KIND:
-                    continue
                 input_to_node[x.debugName()].append(node)

         # build module mapping, from module name to all nodes (as list) under this module scope
@@ -725,8 +717,6 @@ def _build_graph(self):

         # associate module name with their trace graph nodes
         for node in graph.nodes():
-            if node.kind() == CONSTANT_KIND:
-                continue
             module_name = self._get_module_name(node.scopeName())
             if module_name in self.leaf_modules:
                 module_to_nodes[module_name].append(node)

Review comment: when should we skip CONSTANT_KIND and when should we not? I found you did not remove all of them.
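For context on the reviewer's question: constant nodes are the prim::Constant entries that a TorchScript trace emits for literals and other fixed values. A minimal probe of where they show up, assuming CONSTANT_KIND is the string 'prim::Constant' (its definition is not part of these hunks):

import torch

class Tiny(torch.nn.Module):
    def forward(self, x):
        return x * 2.0              # the literal 2.0 becomes a prim::Constant node

traced = torch.jit.trace(Tiny(), torch.randn(1, 3))
CONSTANT_KIND = 'prim::Constant'    # assumption: matches NNI's definition
for node in traced.graph.nodes():
    # the checks removed above used to skip nodes where this prints True
    print(node.kind(), node.kind() == CONSTANT_KIND)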
@@ -91,6 +91,7 @@ def __init__(self, model, dummy_input, masks_file, map_location=None,
         self.constant = {}
         # self.internal_result save the internal output of the submodules
         self.internal_result = {}
+        self.executed_prim_nodes = set()
         self.customized_replace_func = customized_replace_func if customized_replace_func is not None else {}

     def _random_model_input(self, dummy_input, confidence, batch_dim):
@@ -175,12 +176,15 @@ def _prepare_dummy_input(self, node):
         # prepare the inputs and outputs mask for this node,
         # if there is already a mask in self.masks, then use
         # the original mask tensor, else create a new one.
-        inputs_name = node.inputs
+        if node.type == 'module':
+            inputs_name = node.inputs
+        else:
+            inputs_name = [val_node.debugName() for val_node in node.key_node.inputs()]

         # build the dummy_input, in_masks the target node
         dummy_input = []
         debugnames = []
         for _input in inputs_name:
-            if _input not in self.internal_result:
+            if _input not in self.internal_result or _input not in self.masks:
                 # if the input debug name is not in self.internal_result,
                 # then this node isn't a output tensor of any predecessor
                 # nodes. This node is a attribute of the submodule, such as

Review comment: what is key_node? Maybe we need a doc to explain the important attrs in a node.
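On the key_node question: from the diff, key_node appears to be the underlying jit node (a torch._C.Node) that an NNI node group wraps, so the new else branch reads the input names straight off the traced graph. A standalone probe of those jit APIs, independent of NNI:

import torch

traced = torch.jit.trace(torch.nn.Linear(3, 3), torch.randn(1, 3))
for node in traced.graph.nodes():
    # inputs() yields torch._C.Value objects; debugName() is the unique
    # string used above as the key into internal_result and masks
    print(node.kind(), [v.debugName() for v in node.inputs()])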
@@ -197,10 +201,25 @@ def _prepare_dummy_input(self, node):
                 continue
             # The detach operation here is for the in-place operation. We cannot
             # directly can the backward on the output tensor of an in-place operator.
-            dummy_input.append(self.internal_result[_input].detach())
+            dummy_input.append(self.internal_result[_input])

             debugnames.append(_input)

+        def recr_detacher(obj):
+            if isinstance(obj, torch.Tensor):
+                return obj.detach()
+            elif isinstance(obj, tuple):
+                return tuple([recr_detacher(i) for i in obj])
+            elif isinstance(obj, list):
+                return [recr_detacher(i) for i in obj]
+            elif isinstance(obj, set):
+                return set([recr_detacher(i) for i in obj])
+            elif isinstance(obj, dict):
+                return {k: recr_detacher(v) for k, v in obj.items()}
+            else:
+                return obj
+
+        dummy_input = recr_detacher(dummy_input)
         return dummy_input, debugnames

     def update_direct_sparsity(self, node):

Review comment: can obj be a customized data type with tensors inside?
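The review question is worth a runnable check: recr_detacher only recurses into tuple, list, set, and dict, so a tensor held by a custom container falls through the final else branch and stays attached to the autograd graph. A sketch with a hypothetical Wrapper class (not from the PR):

import torch

def recr_detacher(obj):    # copied from the diff above
    if isinstance(obj, torch.Tensor):
        return obj.detach()
    elif isinstance(obj, tuple):
        return tuple([recr_detacher(i) for i in obj])
    elif isinstance(obj, list):
        return [recr_detacher(i) for i in obj]
    elif isinstance(obj, set):
        return set([recr_detacher(i) for i in obj])
    elif isinstance(obj, dict):
        return {k: recr_detacher(v) for k, v in obj.items()}
    else:
        return obj

class Wrapper:             # hypothetical custom container
    def __init__(self, tensor):
        self.tensor = tensor

t = torch.ones(2, requires_grad=True) * 2                 # non-leaf tensor
print(recr_detacher([t])[0].requires_grad)                # False: list is handled
print(recr_detacher(Wrapper(t)).tensor.requires_grad)     # True: falls through else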
@@ -213,6 +232,12 @@ def update_direct_sparsity(self, node):
         module_name = node.name
         _logger.info('Update mask for %s', module_name)
         unique_name = node.unique_name
+        if node.type == 'func':
+            func = jit_to_python_function(node, self)
+            if func is None:
+                # no need to infer the sparsity for this node
+                self.auto_inferences[unique_name] = None
+                return
         dummy_input, input_debugname = self._prepare_dummy_input(node)
         # get the input mask from self.masks
         # Note: the input mask of the successor nodes are
@@ -225,11 +250,7 @@ def update_direct_sparsity(self, node):
         # graph, so we translate it back to python function, Note: the function
         # is appliable to both cpu/gpu devices, the output tensors will be on the
         # same device of the input tensors
-        func = jit_to_python_function(node, self)
-        if func is None:
-            # no need to infer the sparsity for this node
-            self.auto_inferences[unique_name] = None
-            return
+
         # function doesn't have weights
         _auto_infer = AutoMaskInference(
             func, dummy_input, self, in_masks, in_constants=in_constants)
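Taken together, these two hunks move the jit-to-python translation in front of the dummy-input preparation, so unsupported 'func' nodes return early instead of paying for input preparation first. A toy illustration of that control flow, with hypothetical stand-ins for the translate and prepare_inputs steps:

def update_sparsity(node, translate, prepare_inputs, results):
    fn = None
    if node['type'] == 'func':
        fn = translate(node)
        if fn is None:                        # unsupported op: record and bail
            results[node['name']] = None      # out before any input work
            return
    inputs = prepare_inputs(node)             # only reached for supported nodes
    results[node['name']] = fn(*inputs) if fn else inputs

results = {}
update_sparsity({'type': 'func', 'name': 'aten::unsupported'},
                translate=lambda n: None,
                prepare_inputs=lambda n: 1 / 0,   # would raise if it ever ran
                results=results)
print(results)                                # {'aten::unsupported': None}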
@@ -297,16 +318,15 @@ def update_indirect_sparsity(self, node):
                 debug_name = auto_infer.input_debugname[in_id]

                 last_output = self.internal_result[debug_name]
-                # if isinstance(last_output, torch.Tensor):
-                # TODO what if last output is tuple/list of tensor
-                if last_output.grad is not None and tin.grad is not None:
-                    last_output.grad.data += tin.grad.data
-                elif last_output.grad is None:
-                    last_output.grad = tin.grad
-                elif last_output.grad is not None and tin.grad is None:
-                    # for example, tin.view(batch, tin.size(1)/2, tin.view(2)*2)
-                    # the size operation of tin will have no gradient
-                    continue
+                if isinstance(last_output, torch.Tensor):
+                    if last_output.grad is not None and tin.grad is not None:
+                        last_output.grad.data += tin.grad.data
+                    elif last_output.grad is None:
+                        last_output.grad = tin.grad
+                    elif last_output.grad is not None and tin.grad is None:
+                        # for example, tin.view(batch, tin.size(1)/2, tin.view(2)*2)
+                        # the size operation of tin will have no gradient
+                        continue
         else:
             _logger.warning(
                 'Note: %s does not have corresponding mask inference object', node.name)
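For background on the manual accumulation above: when one tensor feeds several successor nodes, each successor's backward pass contributes a partial gradient, and the speedup pass sums those contributions by hand. That mirrors autograd's own fan-out rule, shown here in plain PyTorch:

import torch

x = torch.randn(3, requires_grad=True)
y = x * 2 + x * 3          # x fans out to two consumers
y.sum().backward()
print(x.grad)              # tensor([5., 5., 5.]): the two paths are summed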