This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

fix(speedup): refactor the execution logic of functions in speedup #5107

Closed
wants to merge 9 commits into from
3 changes: 2 additions & 1 deletion docs/source/compression/pruning.rst
@@ -53,7 +53,7 @@ Coarse-grained pruning or structured pruning is pruning a regular group of weights

Only :ref:`level-pruner` and :ref:`admm-pruner` support fine-grained pruning, all other pruners do some kind of structured pruning on weights.

.. _dependency-awareode-for-output-channel-pruning:
.. _dependency-aware-mode-for-output-channel-pruning:

Dependency-aware Mode for Output Channel Pruning
------------------------------------------------
@@ -105,4 +105,5 @@ In addition, for the convolutional layers that have more than one filter group,
``dependency-aware pruner`` will also try to prune the same number of the channels for each filter group.
Overall, this pruner will prune the model according to the L1 norm of each filter and try to meet the topological constraints (channel dependency, etc.) to improve the final speed gain after the speedup process.

Operations that will be recognized as having channel dependencies: add/sub/mul/div, addcmul/addcdiv, logical_and/or/xor
Contributor

.?

In the dependency-aware mode, the pruner will provide a better speed gain from the model pruning.
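
For illustration, here is a minimal PyTorch sketch (not from this PR) of why an elementwise add creates a channel dependency: two convolutional branches feed the same add, so the dependency-aware pruner must keep or prune exactly the same output channels in both.

import torch
import torch.nn as nn

# Two parallel conv branches joined by an elementwise add. Because
# conv1(x) + conv2(x) requires both operands to have identical channel
# layouts, the two convolutions form a channel-dependency set.
class TwoBranch(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv1(x) + self.conv2(x)  # add -> channel dependency

model = TwoBranch()
print(model(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 16, 8, 8])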
10 changes: 0 additions & 10 deletions nni/common/graph_utils.py
@@ -309,8 +309,6 @@ def _expand_key_func_node(self, node, nodes, input_to_node, output_to_node,
while not node_queue.empty():
curr_node = node_queue.get()
for _input in curr_node.inputs():
if _input.node().kind() == CONSTANT_KIND:
continue
input_name = _input.debugName()
if input_name in output_to_node:
for predecessor_node in output_to_node[input_name]:
@@ -693,16 +691,10 @@ def _build_graph(self):
input_to_node = defaultdict(list)
output_to_node = defaultdict(list)
for node in graph.nodes():
if node.kind() == CONSTANT_KIND:
continue
for x in node.outputs():
if x.node().kind() == CONSTANT_KIND:
continue
output_to_node[x.debugName()].append(node)
assert len(output_to_node[x.debugName()]) <= 1, "One output cannot be generated by multiple nodes %s" % x.debugName()
for x in node.inputs():
if x.node().kind() == CONSTANT_KIND:
continue
input_to_node[x.debugName()].append(node)

# build module mapping, from module name to all nodes (as list) under this module scope
@@ -725,8 +717,6 @@

# associate module name with their trace graph nodes
for node in graph.nodes():
if node.kind() == CONSTANT_KIND:
continue
Contributor

When should we skip CONSTANT_KIND and when should we not? I found you did not remove all of the if node.kind() == CONSTANT_KIND: checks in the code.

module_name = self._get_module_name(node.scopeName())
if module_name in self.leaf_modules:
module_to_nodes[module_name].append(node)
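
For context, a rough standalone sketch (simplified, not the NNI code) of the indexing that _build_graph performs on a traced graph: every value's debugName is mapped to the node that produces it and to the nodes that consume it, and prim::Constant nodes can either be filtered here, as the old code did, or handled later when the function nodes are executed.

import torch
from collections import defaultdict

# Trace a tiny module and index its graph values by debugName.
model = torch.nn.Linear(4, 2)
traced = torch.jit.trace(model, torch.randn(1, 4))

input_to_node = defaultdict(list)
output_to_node = defaultdict(list)
for node in traced.graph.nodes():
    # a constant filter could be applied here, as in the old code:
    # if node.kind() == 'prim::Constant': continue
    for x in node.outputs():
        output_to_node[x.debugName()].append(node)
    for x in node.inputs():
        input_to_node[x.debugName()].append(node)

print(list(output_to_node)[:3], list(input_to_node)[:3])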
56 changes: 38 additions & 18 deletions nni/compression/pytorch/speedup/compressor.py
@@ -91,6 +91,7 @@ def __init__(self, model, dummy_input, masks_file, map_location=None,
self.constant = {}
# self.internal_result save the internal output of the submodules
self.internal_result = {}
self.executed_prim_nodes = set()
self.customized_replace_func = customized_replace_func if customized_replace_func is not None else {}

def _random_model_input(self, dummy_input, confidence, batch_dim):
@@ -175,12 +176,15 @@ def _prepare_dummy_input(self, node):
# prepare the inputs and outputs mask for this node,
# if there is already a mask in self.masks, then use
# the original mask tensor, else create a new one.
inputs_name = node.inputs
if node.type == 'module':
inputs_name = node.inputs
else:
inputs_name = [val_node.debugName() for val_node in node.key_node.inputs()]
Contributor

What is key_node? Maybe we need a doc to explain the important attributes of a node.

# build the dummy_input, in_masks the target node
dummy_input = []
debugnames = []
for _input in inputs_name:
if _input not in self.internal_result:
if _input not in self.internal_result or _input not in self.masks:
# if the input debug name is not in self.internal_result,
# then this node isn't an output tensor of any predecessor
# nodes. This node is an attribute of the submodule, such as
@@ -197,10 +201,25 @@
continue
# The detach operation here is for the in-place operation. We cannot
# directly call backward on the output tensor of an in-place operator.
dummy_input.append(self.internal_result[_input].detach())
dummy_input.append(self.internal_result[_input])

debugnames.append(_input)

def recr_detacher(obj):
if isinstance(obj, torch.Tensor):
return obj.detach()
elif isinstance(obj, tuple):
return tuple([recr_detacher(i) for i in obj])
elif isinstance(obj, list):
return [recr_detacher(i) for i in obj]
elif isinstance(obj, set):
return set([recr_detacher(i) for i in obj])
elif isinstance(obj, dict):
return {k: recr_detacher(v) for k, v in obj.items()}
else:
return obj
Contributor

Can obj be a customized data type with a detach function?

try:
    return obj.detach()
except AttributeError:
    return obj


dummy_input = recr_detacher(dummy_input)
return dummy_input, debugnames

def update_direct_sparsity(self, node):
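
To address the review comment above, one possible variant of recr_detacher (a sketch, not the code in this PR) falls back to duck typing so that customized tensor-like objects exposing a detach() method are also handled:

import torch

def recr_detacher(obj):
    # Detach tensors recursively inside common containers; anything else is
    # returned as-is unless it exposes a detach() method itself.
    if isinstance(obj, torch.Tensor):
        return obj.detach()
    if isinstance(obj, tuple):
        return tuple(recr_detacher(i) for i in obj)
    if isinstance(obj, list):
        return [recr_detacher(i) for i in obj]
    if isinstance(obj, set):
        return {recr_detacher(i) for i in obj}
    if isinstance(obj, dict):
        return {k: recr_detacher(v) for k, v in obj.items()}
    try:
        # duck-typed fallback for customized tensor-like data types
        return obj.detach()
    except AttributeError:
        return obj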
@@ -213,6 +232,12 @@
module_name = node.name
_logger.info('Update mask for %s', module_name)
unique_name = node.unique_name
if node.type == 'func':
func = jit_to_python_function(node, self)
if func is None:
# no need to infer the sparsity for this node
self.auto_inferences[unique_name] = None
return
dummy_input, input_debugname = self._prepare_dummy_input(node)
# get the input mask from self.masks
# Note: the input mask of the successor nodes are
@@ -225,11 +250,7 @@
# graph, so we translate it back to python function, Note: the function
# is applicable to both cpu/gpu devices, the output tensors will be on the
# same device of the input tensors
func = jit_to_python_function(node, self)
if func is None:
# no need to infer the sparsity for this node
self.auto_inferences[unique_name] = None
return

# function doesn't have weights
_auto_infer = AutoMaskInference(
func, dummy_input, self, in_masks, in_constants=in_constants)
@@ -297,16 +318,15 @@ def update_indirect_sparsity(self, node):
debug_name = auto_infer.input_debugname[in_id]

last_output = self.internal_result[debug_name]
# if isinstance(last_output, torch.Tensor):
# TODO what if last output is tuple/list of tensor
Contributor

So what if the last output is a tuple/list of tensors?

if last_output.grad is not None and tin.grad is not None:
last_output.grad.data += tin.grad.data
elif last_output.grad is None:
last_output.grad = tin.grad
elif last_output.grad is not None and tin.grad is None:
# for example, tin.view(batch, tin.size(1)/2, tin.view(2)*2)
# the size operation of tin will have no gradient
continue
if isinstance(last_output, torch.Tensor):
if last_output.grad is not None and tin.grad is not None:
last_output.grad.data += tin.grad.data
elif last_output.grad is None:
last_output.grad = tin.grad
elif last_output.grad is not None and tin.grad is None:
# for example, tin.view(batch, tin.size(1)/2, tin.view(2)*2)
# the size operation of tin will have no gradient
continue
else:
_logger.warning(
'Note: %s does not have corresponding mask inference object', node.name)
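
Regarding the question above, a hedged sketch (not part of this PR) of how the gradient merge could be extended if the cached output of a predecessor were a tuple/list of tensors; it assumes the cached structure and the current inputs line up one-to-one:

import torch

def accumulate_grad(last_output, tin):
    # Merge tin's gradient into the cached output, recursing into containers.
    if isinstance(last_output, torch.Tensor) and isinstance(tin, torch.Tensor):
        if last_output.grad is not None and tin.grad is not None:
            last_output.grad.data += tin.grad.data
        elif last_output.grad is None:
            last_output.grad = tin.grad
        # if only tin.grad is None (e.g. tin was used just for size()),
        # there is nothing to accumulate
    elif isinstance(last_output, (list, tuple)):
        for prev, cur in zip(last_output, tin):
            accumulate_grad(prev, cur)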
9 changes: 5 additions & 4 deletions nni/compression/pytorch/speedup/infer_mask.py
@@ -300,15 +300,16 @@ def update_indirect_sparsity(self):
x, torch.Tensor) else x for x in self.dummy_input]
output = self.module(*tmp_dummy_input)

if output.grad_fn is None:
# the output does not have the gradient function
return
# Note: output maybe tensor or list/tuple of tensors
if isinstance(output, torch.Tensor):
if output.grad_fn is None:
# the output does not have the gradient function
return
output.backward(self.output_mask)
elif isinstance(output, list) or isinstance(output, tuple):
for tid, t_out in enumerate(output):
t_out.backward(self.output_mask[tid])
if t_out.grad_fn is not None:
t_out.backward(self.output_mask[tid])

# update the sparsity of the parameters
for para_name in self.weights:
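
As a standalone illustration of the backward pattern used here (made-up module, shapes, and masks, not the NNI code): the output mask is pushed back through the graph as the upstream gradient, and outputs that are not attached to the autograd graph (grad_fn is None) are skipped. retain_graph=True is needed in this toy example because the split outputs share one graph.

import torch

linear = torch.nn.Linear(4, 3)

# single-tensor output
x = torch.randn(2, 4, requires_grad=True)
out = linear(x)
mask = torch.ones_like(out)
if out.grad_fn is not None:
    out.backward(mask)

# tuple of outputs, e.g. produced by torch.split
x2 = torch.randn(2, 4, requires_grad=True)
outs = torch.split(linear(x2), 1, dim=1)
masks = [torch.ones_like(t) for t in outs]
for tid, t_out in enumerate(outs):
    if t_out.grad_fn is not None:
        t_out.backward(masks[tid], retain_graph=True)

print(x.grad.shape, x2.grad.shape)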