From 7eedec46aab3593d387376b56903a72ee3a3a08d Mon Sep 17 00:00:00 2001
From: Ningxin Zheng <49771382+zheng-ningxin@users.noreply.github.com>
Date: Wed, 14 Jul 2021 14:24:40 +0800
Subject: [PATCH] Model Speedup Refactor (#3462)

---
 .../Compression/CompressionReference.rst      |    3 -
 nni/common/graph_utils.py                     |   14 +-
 .../pytorch/speedup/compress_modules.py       |  595 ++++++---
 nni/compression/pytorch/speedup/compressor.py |  579 +++++++--
 nni/compression/pytorch/speedup/infer_mask.py |  378 ++++++
 .../pytorch/speedup/infer_shape.py            | 1146 -----------------
 .../pytorch/speedup/jit_translate.py          |  553 ++++++++
 nni/compression/pytorch/utils/__init__.py     |    1 +
 .../pytorch/utils/mask_conflict.py            |  140 +-
 .../pytorch/utils/shape_dependency.py         |  276 ++--
 nni/compression/pytorch/utils/utils.py        |   52 +
 test/ut/sdk/test_compression_utils.py         |    2 +-
 test/ut/sdk/test_model_speedup.py             |  180 ++-
 13 files changed, 2218 insertions(+), 1701 deletions(-)
 create mode 100644 nni/compression/pytorch/speedup/infer_mask.py
 delete mode 100644 nni/compression/pytorch/speedup/infer_shape.py
 create mode 100644 nni/compression/pytorch/speedup/jit_translate.py

diff --git a/docs/en_US/Compression/CompressionReference.rst b/docs/en_US/Compression/CompressionReference.rst
index 50dcc12876..5903500115 100644
--- a/docs/en_US/Compression/CompressionReference.rst
+++ b/docs/en_US/Compression/CompressionReference.rst
@@ -140,9 +140,6 @@ Topology Utilities
 ..  autoclass:: nni.compression.pytorch.utils.shape_dependency.GroupDependency
     :members:

-..  autoclass:: nni.compression.pytorch.utils.mask_conflict.CatMaskPadding
-    :members:
-
 ..  autoclass:: nni.compression.pytorch.utils.mask_conflict.GroupMaskConflict
     :members:

diff --git a/nni/common/graph_utils.py b/nni/common/graph_utils.py
index fe6c68cbc0..1c1f4aaf91 100644
--- a/nni/common/graph_utils.py
+++ b/nni/common/graph_utils.py
@@ -71,7 +71,11 @@ def __init__(self, model=None, dummy_input=None, traced_model=None):
     def _trace(self, model, dummy_input):
         training = model.training
         model.eval()
-        self.trace = torch.jit.trace(model, dummy_input)
+        kw_args = {}
+        if torch.__version__ >= '1.6.0':
+            # only pytorch with version greater than or equal to 1.6.0 has the strict option
+            kw_args['strict'] = False
+        self.trace = torch.jit.trace(model, dummy_input, **kw_args)
         torch._C._jit_pass_inline(self.trace.graph)
         model.train(training)

@@ -247,6 +251,7 @@ class TorchModuleGraph(TorchGraph):
     def __init__(self, model=None, dummy_input=None, traced_model=None):
         super().__init__(model, dummy_input, traced_model)
         self.global_count = 0
+        self.reused_module = set()
         self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
         self._extract_auxiliary_info()

@@ -390,9 +395,12 @@ def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
                 outputs.append(output_name)
             else:
                 outputs.append(output_name)
+        unique_outputs = list(set(outputs))
+        # remove the duplicated output names
+        unique_outputs.sort(key=outputs.index)
         nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
-                             node_group, inputs=list(inputs), outputs=list(outputs))
+                             node_group, inputs=list(inputs), outputs=unique_outputs)
         return nodepy

     def _extract_cat_info(self, node_group, cpp_node):
@@ -724,6 +732,8 @@ def _build_graph(self):
                     unique_name = module_name
                     if use_count > 0:
                         unique_name = module_name + '.%d' % use_count
+                        self.reused_module.add(unique_name)
+                        self.reused_module.add(module_name)
                     node_group = self._expand_module_node(
                         node, module_name, unique_name, module_to_type[module_name], node_cpps, input_to_node,
                         output_to_node, 'module')
diff --git a/nni/compression/pytorch/speedup/compress_modules.py b/nni/compression/pytorch/speedup/compress_modules.py
index c382c0b7e2..c4a922b9ed 100644
--- a/nni/compression/pytorch/speedup/compress_modules.py
+++ b/nni/compression/pytorch/speedup/compress_modules.py
@@ -3,222 +3,394 @@

 import logging
 import torch
-from .infer_shape import ModuleMasks
+import torch.nn as nn

 _logger = logging.getLogger(__name__)

 replace_module = {
-    'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
-    'Conv2d': lambda module, mask: replace_conv2d(module, mask),
-    'ConvTranspose2d': lambda module, mask: replace_convtranspose2d(module, mask),
-    'MaxPool2d': lambda module, mask: no_replace(module, mask),
-    'AvgPool2d': lambda module, mask: no_replace(module, mask),
-    'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
-    'ReLU': lambda module, mask: no_replace(module, mask),
-    'PReLU': lambda module, mask: replace_prelu(module, mask),
-    'ReLU6': lambda module, mask: no_replace(module, mask),
-    'Sigmoid': lambda module, mask: no_replace(module, mask),
-    'Linear': lambda module, mask: replace_linear(module, mask),
-    'Dropout': lambda module, mask: no_replace(module, mask),
-    'Dropout2d': lambda module, mask: no_replace(module, mask),
-    'Dropout3d': lambda module, mask: no_replace(module, mask)
+    'BatchNorm2d': lambda module, masks: replace_batchnorm2d(module, masks),
+    'BatchNorm1d': lambda module, masks: replace_batchnorm1d(module, masks),
+    'Conv2d': lambda module, masks: replace_conv2d(module, masks),
+    'Linear': lambda module, masks: replace_linear(module, masks),
+    'MaxPool2d': lambda module, masks: no_replace(module, masks),
+    'AvgPool2d': lambda module, masks: no_replace(module, masks),
+    'AdaptiveAvgPool2d': lambda module, masks: no_replace(module, masks),
+    'ReLU': lambda module, masks: no_replace(module, masks),
+    'ReLU6': lambda module, masks: no_replace(module, masks),
+    'LeakyReLU': lambda module, masks: no_replace(module, masks),
+    'ELU': lambda module, masks: no_replace(module, masks),
+    'Hardtanh': lambda module, masks: no_replace(module, masks),
+    'Hardsigmoid': lambda module, masks: no_replace(module, masks),
+    'LogSigmoid': lambda module, masks: no_replace(module, masks),
+    'PReLU': lambda module, masks: replace_prelu(module, masks),
+    'RReLU': lambda module, masks: no_replace(module, masks),
+    'SELU': lambda module, masks: no_replace(module, masks),
+    'CELU': lambda module, masks: no_replace(module, masks),
+    'GELU': lambda module, masks: no_replace(module, masks),
+    'Sigmoid': lambda module, masks: no_replace(module, masks),
+    'SiLU': lambda module, masks: no_replace(module, masks),
+    'Mish': lambda module, masks: no_replace(module, masks),
+    'Tanh': lambda module, masks: no_replace(module, masks),
+    'Softplus': lambda module, masks: no_replace(module, masks),
+    'Softshrink': lambda module, masks: no_replace(module, masks),
+    'Softmax': lambda module, masks: no_replace(module, masks),
+    'Tanhshrink': lambda module, masks: no_replace(module, masks),
+    'Dropout': lambda module, masks: no_replace(module, masks),
+    'Dropout2d': lambda module, masks: no_replace(module, masks),
+    'Dropout3d': lambda module, masks: no_replace(module, masks),
+    'Upsample': lambda module, masks: no_replace(module, masks),
+    'LayerNorm': lambda module, masks: replace_layernorm(module, masks),
+    'ConvTranspose2d': lambda module, masks: replace_convtranspose2d(module, masks)
 }

-def no_replace(module, mask):
+
+def convert_to_coarse_mask(t_mask, dim):
+    """
+    Convert the mask tensor to the coarse-grained mask tensor.
+
+    Parameters
+    ----------
+    t_mask: torch.Tensor
+        The tensor only has 1s and 0s, 0 indicates this value is masked
+        and 1 indicates the corresponding value is not masked.
+    dim: int
+        Try to reduce the mask tensor on this dimension.
+
+    Returns
+    -------
+    indexes: torch.Tensor
+        The indexes of the sparsity that can be structurally removed.
+    remained_indexes: torch.Tensor
+        The indexes of the values that need to be retained.
+    """
+    assert isinstance(t_mask, torch.Tensor)
+    shape = list(t_mask.size())
+    n_dims = len(shape)
+    dim_list = list(range(n_dims))
+    # try to reduce the mask from the dim-th dimension
+    dim_list.remove(dim)
+
+    t_merged = torch.sum(t_mask, dim_list)
+    assert t_merged.size(0) == shape[dim]
+    all_pruned = t_merged == 0
+    need_remain = t_merged != 0
+    # return the indexes of the sparsity that can be removed
+    indexes = torch.nonzero(all_pruned, as_tuple=True)[0]
+    remained_indexes = torch.nonzero(need_remain, as_tuple=True)[0]
+    return indexes, remained_indexes
+
+
+def no_replace(module, masks):
     """
     No need to replace
     """
     _logger.debug("no need to replace")
     return module

-def replace_prelu(norm, mask):
+def replace_prelu(prelu, masks):
     """
     Parameters
     ----------
-    norm : torch.nn.BatchNorm2d
+    prelu : torch.nn.PReLU
         The prelu module to be replaced
-    mask : ModuleMasks
-        The masks of this module
+    masks : tuple of masks
+        The input/output/weight masks of the target module

     Returns
     -------
     torch.nn.PReLU
         The new prelu module
     """
-    assert isinstance(mask, ModuleMasks)
-    assert 'weight' in mask.param_masks
-    index = mask.param_masks['weight'].mask_index[0]
-    num_features = index.size()[0]
-    # _logger.debug("replace prelu with num_features: %d", num_features)
-    if num_features == 0:
+    in_masks, output_mask, weight_mask = masks
+    assert len(in_masks) == 1
+    assert isinstance(output_mask, torch.Tensor)
+    in_mask = in_masks[0]
+    weight_mask = weight_mask['weight']
+    pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
+    pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
+    n_remained_in = weight_mask.size(0) - pruned_in.size(0)
+    n_remained_out = weight_mask.size(0) - pruned_out.size(0)
+    remained_in, remained_out = remained_in.to(
+        prelu.weight.device), remained_out.to(prelu.weight.device)
+    assert n_remained_in == n_remained_out
+    if n_remained_in == 0:
         return torch.nn.Identity()
-    new_norm = torch.nn.PReLU(num_features)
-    # assign weights
-    new_norm.weight.data = torch.index_select(norm.weight.data, 0, index)
-    return new_norm
+    new_prelu = torch.nn.PReLU(n_remained_in)
+    new_prelu.weight.data = torch.index_select(prelu.weight.data, 0, remained_in)
+    return new_prelu

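To make the reduction performed by `convert_to_coarse_mask` concrete, here is a small standalone sketch (not part of the patch; tensor shapes are illustrative) of how a fine-grained weight mask collapses into coarse, structurally removable channel indexes:

```python
import torch

def convert_to_coarse_mask(t_mask, dim):
    # sum over every dimension except `dim`; a zero sum means the whole
    # slice along `dim` is masked and can be removed structurally
    dim_list = [d for d in range(len(t_mask.size())) if d != dim]
    t_merged = torch.sum(t_mask, dim_list)
    return (torch.nonzero(t_merged == 0, as_tuple=True)[0],
            torch.nonzero(t_merged != 0, as_tuple=True)[0])

# a Conv2d weight mask of shape [out_channels=4, in_channels=3, 3, 3]
weight_mask = torch.ones(4, 3, 3, 3)
weight_mask[1] = 0            # filter 1 fully masked -> prunable
weight_mask[3, 0, 0, 0] = 0   # filter 3 only partially masked -> kept
pruned, remained = convert_to_coarse_mask(weight_mask, dim=0)
print(pruned.tolist(), remained.tolist())  # [1] [0, 2, 3]
```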
-def replace_linear(linear, mask):
+def replace_linear(linear, masks):
     """
+    This function will replace the original linear module according to
+    the inferred masks. This function supports both fine-grained and
+    coarse-grained sparsity. In the fine-grained scenario, this function
+    will remove the whole columns/rows that happen to be totally covered by
+    the masks.
+
     Parameters
     ----------
     linear : torch.nn.Linear
         The linear module to be replaced
-    mask : ModuleMasks
-        The masks of this module
+    masks : Tuple of the input masks, output masks and weight masks
+        Tuple of the masks, for example
+        ([input_m1, input_m2], [output_m], {'weight':weight_m})

     Returns
     -------
     torch.nn.Linear
         The new linear module
     """
-    assert isinstance(mask, ModuleMasks)
-    assert mask.input_mask is not None
-    assert mask.output_mask is None
-    assert not mask.param_masks
-    index = mask.input_mask.mask_index[-1]
-    in_features = index.size()[0]
-    _logger.debug("replace linear with new in_features: %d", in_features)
-    new_linear = torch.nn.Linear(in_features=in_features,
-                                 out_features=linear.out_features,
-                                 bias=linear.bias is not None)
-    new_linear.to(linear.weight.device)
-    new_linear.weight.data = torch.index_select(
-        linear.weight.data, -1, index.to(linear.weight.device))
+    in_masks, output_mask, weight_mask = masks
+    assert isinstance(linear, nn.Linear)
+    assert len(in_masks) == 1
+    assert isinstance(output_mask, torch.Tensor)
+    in_mask = in_masks[0]
+
+    weight_mask = weight_mask['weight']
+    # N C K
+    pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
+    pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
+    n_remained_in = weight_mask.size(1) - pruned_in.size(0)
+    n_remained_out = weight_mask.size(0) - pruned_out.size(0)
+    remained_in, remained_out = remained_in.to(
+        linear.weight.device), remained_out.to(linear.weight.device)
+    _logger.info("replace linear with new in_features: %d, out_features: %d",
+                 n_remained_in, n_remained_out)
+    need_bias = False
     if linear.bias is not None:
-        new_linear.bias.data.copy_(linear.bias.data)
+        need_bias = True
+    new_linear = torch.nn.Linear(in_features=n_remained_in,
+                                 out_features=n_remained_out,
+                                 bias=need_bias)
+    new_linear.to(linear.weight.device)
+    # Copy the remained weight from the original module
+    with torch.no_grad():
+        tmp_weight_data = torch.index_select(
+            linear.weight.data, 0, remained_out)
+        new_linear.weight.data = torch.index_select(
+            tmp_weight_data, 1, remained_in)
+
+        if linear.bias is not None:
+            new_linear.bias.data = torch.index_select(
+                linear.bias.data, 0, remained_out)
+
     return new_linear

-def replace_batchnorm2d(norm, mask):
+def replace_batchnorm1d(norm, masks):
+    """
+    Parameters
+    ----------
+    norm : torch.nn.BatchNorm1d
+        The batchnorm module to be replaced
+    masks : Tuple of the input masks, output masks and weight masks
+        Tuple of the masks, for example
+        ([input_m1, input_m2], [output_m], {'weight':weight_m})
+
+    Returns
+    -------
+    torch.nn.BatchNorm1d
+        The new batchnorm module
+    """
+    in_masks, output_mask, _ = masks
+    assert isinstance(norm, nn.BatchNorm1d)
+    in_mask = in_masks[0]
+
+    # N, C, L
+    _, remained_in = convert_to_coarse_mask(in_mask, 1)
+    _, remained_out = convert_to_coarse_mask(output_mask, 1)
+    assert remained_in.size(0) == remained_out.size(0)
+
+    num_features = remained_in.size(0)
+    _logger.info("replace batchnorm1d with num_features: %d", num_features)
+    new_norm = torch.nn.BatchNorm1d(num_features=num_features,
+                                    eps=norm.eps,
+                                    momentum=norm.momentum,
+                                    affine=norm.affine,
+                                    track_running_stats=norm.track_running_stats)
+    # assign weights
+    new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
+    new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
+
+    new_norm.running_mean.data = torch.index_select(
+        norm.running_mean.data, 0, remained_in)
+    new_norm.running_var.data = torch.index_select(
+        norm.running_var.data, 0, remained_in)
+    return new_norm
+
+
+def replace_batchnorm2d(norm, masks):
     """
     Parameters
     ----------
     norm : torch.nn.BatchNorm2d
         The batchnorm module to be replaced
-    mask : ModuleMasks
-        The masks of this module
+    masks : Tuple of the input masks, output masks and weight masks
+        Tuple of the masks, for example
+        ([input_m1, input_m2], [output_m], {'weight':weight_m})

     Returns
     -------
     torch.nn.BatchNorm2d
         The new batchnorm module
     """
-    assert isinstance(mask, ModuleMasks)
-    assert 'weight' in mask.param_masks and 'bias' in mask.param_masks
-    index = mask.param_masks['weight'].mask_index[0]
-    num_features = index.size()[0]
-    _logger.debug("replace batchnorm2d with num_features: %d", num_features)
+    in_masks, output_mask, _ = masks
+    assert isinstance(norm, nn.BatchNorm2d)
+    in_mask = in_masks[0]
+
+    # N, C, H, W
+    _, remained_in = convert_to_coarse_mask(in_mask, 1)
+    _, remained_out = convert_to_coarse_mask(output_mask, 1)
+    assert remained_in.size(0) == remained_out.size(0)
+
+    num_features = remained_in.size(0)
+    _logger.info("replace batchnorm2d with num_features: %d", num_features)
     new_norm = torch.nn.BatchNorm2d(num_features=num_features,
                                     eps=norm.eps,
                                     momentum=norm.momentum,
                                     affine=norm.affine,
                                     track_running_stats=norm.track_running_stats)
     # assign weights
-    new_norm.weight.data = torch.index_select(norm.weight.data, 0, index)
-    new_norm.bias.data = torch.index_select(norm.bias.data, 0, index)
-    if norm.track_running_stats:
-        new_norm.running_mean.data = torch.index_select(
-            norm.running_mean.data, 0, index)
-        new_norm.running_var.data = torch.index_select(
-            norm.running_var.data, 0, index)
+    new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
+    new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
+
+    new_norm.running_mean.data = torch.index_select(
+        norm.running_mean.data, 0, remained_in)
+    new_norm.running_var.data = torch.index_select(
+        norm.running_var.data, 0, remained_in)
     return new_norm
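The grouped-convolution bookkeeping in `replace_conv2d` below is easier to see on plain Python lists; this sketch (illustrative numbers, not part of the patch) shows how surviving channels are counted per group and how fully pruned groups shrink `groups`:

```python
remained_in = [0, 1, 4, 5]      # kept input-channel indexes (global)
in_channels, groups = 8, 4      # originally 2 input channels per group
step = in_channels // groups
per_group = [[c for c in remained_in if g * step <= c < (g + 1) * step]
             for g in range(groups)]
# fully pruned groups are dropped; the surviving groups must keep the
# same number of channels so the new conv is still a valid grouped conv
new_groups = sum(1 for g in per_group if g)
print(per_group, new_groups)    # [[0, 1], [], [4, 5], []] 2
```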
-def replace_conv2d(conv, mask):
+
+def replace_conv2d(conv, masks):
     """
+    Replace the original conv with a new one according to the inferred
+    masks. This function supports both fine-grained and coarse-grained
+    sparsity. In the fine-grained scenario, this replace function will prune
+    the filters that happen to be totally covered by the fine-grained sparsity.
+
     Parameters
     ----------
     conv : torch.nn.Conv2d
         The conv2d module to be replaced
-    mask : ModuleMasks
-        The masks of this module
+    masks : Tuple of the input masks, output masks and weight masks
+        Tuple of the masks, for example
+        ([input_m1, input_m2], [output_m], {'weight':weight_m})

     Returns
     -------
     torch.nn.Conv2d
         The new conv2d module
     """
-    assert isinstance(mask, ModuleMasks)
-    if mask.input_mask is None:
-        in_channels = conv.in_channels
-    else:
-        in_channels_index = mask.input_mask.mask_index[1]
-        in_channels = in_channels_index.size()[0]
-    if mask.output_mask is None:
-        out_channels = conv.out_channels
-    else:
-        out_channels_index = mask.output_mask.mask_index[1]
-        out_channels = out_channels_index.size()[0]
-    groups = conv.groups
-    if conv.in_channels == conv.out_channels == conv.groups:
-        # remove groups for depthwise layers
-        assert in_channels == out_channels
-        groups = in_channels
-    _logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d",
-                  mask.module_name, in_channels, out_channels)
-    new_conv = torch.nn.Conv2d(in_channels=in_channels,
-                               out_channels=out_channels,
+    in_masks, output_mask, weight_masks = masks
+    assert isinstance(conv, nn.Conv2d)
+    # the conv layer should only have one input tensor
+    assert len(in_masks) == 1
+
+    in_mask = in_masks[0]
+
+    weight_mask = weight_masks['weight']
+    pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
+    pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
+
+    n_remained_in = weight_mask.size(1) * conv.groups - pruned_in.size(0)
+    n_remained_out = weight_mask.size(0) - pruned_out.size(0)
+
+    assert n_remained_in == remained_in.size(0)
+    assert n_remained_out == remained_out.size(0)
+
+    k_size1, k_size2 = conv.kernel_size
+    # Note: we should resolve the group dependency of the conv layers before
+    # running into this function.
+    # check if the mask tensor meets the group dependency and calculate the
+    # new number of the groups after pruning
+    # the original step size of the input channel for each group
+    ori_inchannel_step = int(conv.in_channels/conv.groups)
+    # the original step size of the output channel for each group
+    ori_outchannel_step = int(conv.out_channels/conv.groups)
+    # calculate the new_inchannel_step and new_outchannel_step first
+    new_inchannel_step = new_outchannel_step = None
+    for groupid in range(conv.groups):
+        in_start = groupid * ori_inchannel_step
+        in_end = in_start + ori_inchannel_step
+        out_start = groupid * ori_outchannel_step
+        out_end = out_start + ori_outchannel_step
+        current_input_index = list(
+            filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
+        current_output_index = list(
+            filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
+        # remap the global index to the group index
+        if len(current_input_index) == 0:
+            # if the whole group is pruned
+            continue
+        else:
+            new_inchannel_step = len(current_input_index)
+            new_outchannel_step = len(current_output_index)
+            break
+    tmp_weight = torch.ones(
+        n_remained_out, new_inchannel_step, k_size1, k_size2)
+    tmp_weight = tmp_weight.to(conv.weight.device)
+
+    assert n_remained_in % new_inchannel_step == 0
+    assert n_remained_out % new_outchannel_step == 0
+
+    new_groups = 0
+    for groupid in range(conv.groups):
+        in_start = groupid * ori_inchannel_step
+        in_end = in_start + ori_inchannel_step
+        out_start = groupid * ori_outchannel_step
+        out_end = out_start + ori_outchannel_step
+        current_input_index = list(
+            filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
+        current_output_index = list(
+            filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
+        # remap the global index to the group index
+        current_input_index = [x-in_start for x in current_input_index]
+        if len(current_input_index) == 0:
+            # if the whole group is pruned
+            assert len(current_output_index) == 0
+            continue
+        # check that the number of remained channels of each group is the same
+        assert len(current_input_index) == new_inchannel_step
+        assert len(current_output_index) == new_outchannel_step
+        # copy the weight into tmp_weight
+        new_out_start = new_outchannel_step * new_groups
+        new_out_end = new_out_start + new_outchannel_step
+        tmp_weight[new_out_start:new_out_end] = torch.index_select(
+            conv.weight[current_output_index], 1, torch.as_tensor(current_input_index, dtype=torch.long).to(conv.weight.device))
+        new_groups += 1
+
+    _logger.debug("replace conv2d with in_channels: %d, out_channels: %d",
+                  n_remained_in, n_remained_out)
+
+    # need_bias is a flag that indicates whether the new conv layer needs a
+    # bias: if the original conv doesn't have a bias and there is no constant
+    # that needs to be folded into the bias, need_bias is False.
+    need_bias = conv.bias is not None
+    new_conv = torch.nn.Conv2d(in_channels=n_remained_in,
+                               out_channels=n_remained_out,
                                kernel_size=conv.kernel_size,
                                stride=conv.stride,
                                padding=conv.padding,
                                dilation=conv.dilation,
-                               groups=groups,
-                               bias=conv.bias is not None,
+                               groups=new_groups,
+                               bias=need_bias,
                                padding_mode=conv.padding_mode)
     new_conv.to(conv.weight.device)
-    tmp_weight_data = tmp_bias_data = None
-
-    if mask.output_mask is not None:
-        tmp_weight_data = torch.index_select(
-            conv.weight.data, 0, out_channels_index)
-        if conv.bias is not None:
-            tmp_bias_data = torch.index_select(
-                conv.bias.data, 0, out_channels_index)
-    else:
-        tmp_weight_data = conv.weight.data
-    # For the convolutional layers that have more than one group
-    # we need to copy the weight group by group, because the input
-    # channal is also divided into serveral groups and each group
-    # filter may have different input channel indexes.
-    input_step = int(conv.in_channels / conv.groups)
-    in_channels_group = int(in_channels / groups)
-    filter_step = int(out_channels / groups)
-    if mask.input_mask is not None and not (in_channels == out_channels == groups):
-        for groupid in range(conv.groups):
-            start = groupid * input_step
-            end = (groupid + 1) * input_step
-            current_input_index = list(
-                filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
-            if not current_input_index:
-                # there is no kept channel in current group
-                # TODO bug here, the groups is directly get from conv.groups, if the whole group is removed,
-                # then the number of groups in the new_conv also need to change
-                raise Exception(
-                    " Donnot support removing the whole group filter except in the depth-wise conv temporarily")
-            # shift the global index into the group index
-            current_input_index = [x-start for x in current_input_index]
-            # if the groups is larger than 1, the input channels of each
-            # group should be pruned evenly.
-            assert len(current_input_index) == in_channels_group, \
-                'Input channels of each group are not pruned evenly'
-            current_input_index = torch.tensor(current_input_index).to(tmp_weight_data.device)  # pylint: disable=not-callable
-            f_start = groupid * filter_step
-            f_end = (groupid + 1) * filter_step
-            new_conv.weight.data[f_start:f_end] = torch.index_select(
-                tmp_weight_data[f_start:f_end], 1, current_input_index)
-    else:
-        new_conv.weight.data.copy_(tmp_weight_data)
+    new_conv.weight.copy_(tmp_weight)
+    # copy the bias data
     if conv.bias is not None:
-        new_conv.bias.data.copy_(
-            conv.bias.data if tmp_bias_data is None else tmp_bias_data)
+        new_conv.bias.data.copy_(torch.index_select(
+            conv.bias.data, 0, remained_out))
+
     return new_conv

-def replace_convtranspose2d(convtrans, mask):
+def replace_convtranspose2d(convtrans, masks):
     """
     We need another replace function for convtranspose2d, because the layout of
@@ -228,81 +400,120 @@
     ----------
     convtrans : torch.nn.ConvTranspose2d
         The convtranspose2d module to be replaced
-    mask : ModuleMasks
-        The masks of this module
+    masks : Tuple of the input masks, output masks and weight masks
+        Tuple of the masks, for example
+        ([input_m1, input_m2], [output_m], {'weight':weight_m})

     Returns
     -------
     torch.nn.ConvTranspose2d
         The new convtranspose2d module
     """
-    assert isinstance(mask, ModuleMasks)
+    in_masks, output_mask, weight_masks = masks
     assert isinstance(convtrans, torch.nn.ConvTranspose2d)
-    if mask.input_mask is None:
-        in_channels = convtrans.in_channels
-    else:
-        in_channels_index = mask.input_mask.mask_index[1]
-        in_channels = in_channels_index.size(0)
-    if mask.output_mask is None:
-        out_channels = convtrans.out_channels
-    else:
-        out_channels_index = mask.output_mask.mask_index[1]
-        out_channels = out_channels_index.size(0)
-    groups = convtrans.groups
-    # check if can remove the whole group of filters
-    if convtrans.in_channels == convtrans.out_channels == convtrans.groups:
-        # remove groups for depthwise layers
-        # this needs the group dependency to be fixed before the speedup
-        assert in_channels == out_channels
-        groups = in_channels
-    _logger.debug('Replace convtranspose2d %s with in_channels:%d out_channels:%d',
-                  mask.module_name, in_channels, out_channels)
-    new_convtrans = torch.nn.ConvTranspose2d(in_channels=in_channels,
-                                             out_channels=out_channels,
+    assert len(in_masks) == 1
+    in_mask = in_masks[0]
+
+    weight_mask = weight_masks['weight']
+    pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
+    pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
+    # ConvTranspose2d has the weight shape of [N_in, N_out/groups, k1, k2]
+    n_remained_in = weight_mask.size(0) - pruned_in.size(0)
+    n_remained_out = weight_mask.size(
+        1) * convtrans.groups - pruned_out.size(0)
+    assert n_remained_in == remained_in.size(0)
+    assert n_remained_out == remained_out.size(0)
+    k_size1, k_size2 = convtrans.kernel_size
+    # Note: we should resolve the group dependency of the convtrans layers before
+    # running into this function
+    ori_inchannel_step = int(convtrans.in_channels/convtrans.groups)
+    ori_outchannel_step = int(convtrans.out_channels/convtrans.groups)
+    new_inchannel_step = new_outchannel_step = None
+    for groupid in range(convtrans.groups):
+        in_start = groupid * ori_inchannel_step
+        in_end = in_start + ori_inchannel_step
+        out_start = groupid * ori_outchannel_step
+        out_end = out_start + ori_outchannel_step
+        current_input_index = list(
+            filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
+        current_output_index = list(
+            filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
+        if len(current_input_index) == 0:
+            # if the whole group is pruned
+            continue
+        else:
+            new_inchannel_step = len(current_input_index)
+            new_outchannel_step = len(current_output_index)
+            break
+    tmp_weight = torch.ones(
+        n_remained_in, new_outchannel_step, k_size1, k_size2)
+    tmp_weight = tmp_weight.to(convtrans.weight.device)
+
+    assert n_remained_in % new_inchannel_step == 0
+    assert n_remained_out % new_outchannel_step == 0
+
+    new_groups = 0
+    for groupid in range(convtrans.groups):
+        # copy the weights of this group
+        in_start = groupid * ori_inchannel_step
+        in_end = in_start + ori_inchannel_step
+        out_start = groupid * ori_outchannel_step
+        out_end = out_start + ori_outchannel_step
+        current_input_index = list(
+            filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
+        current_output_index = list(
+            filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
+        # remap the global index to the group index
+        # in the convtranspose layer, the groups are on
+        # the output channel dimension
+        current_output_index = [x-out_start for x in current_output_index]
+        if len(current_input_index) == 0:
+            # if the whole group is pruned
+            assert len(current_output_index) == 0
+            continue
+        # check that the number of remained channels of each group is the same
+        assert len(current_input_index) == new_inchannel_step
+        assert len(current_output_index) == new_outchannel_step
+        # copy the weight into tmp_weight
+        new_in_start = new_inchannel_step * new_groups
+        new_in_end = new_in_start + new_inchannel_step
+        tmp_weight[new_in_start:new_in_end] = torch.index_select(
+            convtrans.weight[current_input_index], 1, torch.as_tensor(current_output_index, dtype=torch.long).to(convtrans.weight.device))
+        new_groups += 1
+
+    _logger.debug('Replace convtranspose2d with in_channels:%d out_channels:%d',
+                  n_remained_in, n_remained_out)
+    new_convtrans = torch.nn.ConvTranspose2d(in_channels=n_remained_in,
+                                             out_channels=n_remained_out,
                                              kernel_size=convtrans.kernel_size,
                                              stride=convtrans.stride,
                                              padding=convtrans.padding,
                                              dilation=convtrans.dilation,
-                                             groups=groups,
+                                             groups=new_groups,
                                              bias=convtrans.bias is not None,
                                              padding_mode=convtrans.padding_mode)
     new_convtrans.to(convtrans.weight.device)
-    tmp_weight_data = None
-    if mask.input_mask is not None:
-        # in convtranspose2d we need to select the input channel first
-        tmp_weight_data = torch.index_select(
-            convtrans.weight.data, 0, in_channels_index)
-    else:
-        tmp_weight_data = convtrans.weight.data
-    # we need to handle the output channel group by group like the conv layer
-    out_step = int(convtrans.out_channels / convtrans.groups)
-    out_channel_group = int(out_channels/groups)
-    new_in_per_group = int(in_channels/groups)
-
-    if mask.output_mask is not None and not(in_channels == out_channels == groups):
-        for groupid in range(convtrans.groups):
-            start = groupid * out_step
-            end = (groupid + 1) * out_step
-            current_output_index = list(
-                filter(lambda x: start <= x and x < end, out_channels_index.tolist()))
-            # we need to shift the index into the group-wise
-            current_output_index = [x-start for x in current_output_index]
-            if not current_output_index:
-                # No kept channel in the current group
-                raise Exception(
-                    " Donnot support removing the whole group filter except in the depth-wise conv temporarily")
-            assert len(current_output_index) == out_channel_group, \
-                'Output channel of each group should be the same after pruning'
-            current_output_index = torch.tensor(current_output_index).to(tmp_weight_data.device)  # pylint: disable=not-callable
-            new_start = groupid * new_in_per_group
-            new_end = (groupid + 1) * new_in_per_group
-            new_convtrans.weight.data[new_start:new_end] = torch.index_select(
-                tmp_weight_data[new_start:new_end], 1, current_output_index)
-    else:
-        new_convtrans.weight.data.copy_(tmp_weight_data)
+    new_convtrans.weight.copy_(tmp_weight)
     if convtrans.bias is not None:
-        if mask.output_mask is not None:
+        if output_mask is not None:
             new_convtrans.bias.data[:] = torch.index_select(
-                convtrans.bias.data, 0, out_channels_index)
+                convtrans.bias.data, 0, remained_out)
         else:
             new_convtrans.bias.data.copy_(convtrans.bias.data)
     return new_convtrans
+
+
+def replace_layernorm(layernorm, masks):
+    # compute the normalized shape that survives the input mask
+    in_masks, _, _ = masks
+    assert isinstance(layernorm, nn.LayerNorm)
+    assert len(in_masks) == 1
+    in_mask = in_masks[0]
+    dim_n = len(in_mask.size())
+    new_shape = []
+    for i in range(1, dim_n):
+        sum_dims = list(range(0, dim_n))
+        sum_dims.remove(i)
+        reduced = torch.sum(in_mask, sum_dims)
+        n_remained = torch.sum(reduced > 0)
+        new_shape.append(n_remained)
+
+    return nn.LayerNorm(tuple(new_shape), layernorm.eps, layernorm.elementwise_affine)
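Before the compressor.py changes, a usage sketch of the refactored entry point; the model and mask file path are placeholders, and this mirrors the documented NNI workflow rather than code contained in this patch:

```python
import torch
from torchvision.models import resnet18
from nni.compression.pytorch import ModelSpeedup

model = resnet18()
dummy_input = torch.rand(8, 3, 224, 224)   # first dimension is the batch size
# 'mask.pth' stands in for the mask file produced by a pruner
ModelSpeedup(model, dummy_input, 'mask.pth', confidence=8).speedup_model()
```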
""" - from nni.common.graph_utils import build_module_graph - + assert confidence > 1 + # The auto inference will change the values of the parameters in the model + # so we need make a copy before the mask inference + self.ori_state_dict = copy.deepcopy(model.state_dict()) self.bound_model = model - self.masks = torch.load(masks_file, map_location) - self.inferred_masks = dict() # key: module_name, value: ModuleMasks - self.dummy_input = dummy_input - self.torch_graph = build_module_graph(model, dummy_input) + self.inferred_masks = dict() # key: module_name, value: ModuleMasks + self.batch_dim = batch_dim + self.dummy_input, self.device = self._random_model_input(dummy_input, confidence, batch_dim) + self.torch_graph = build_module_graph(model, self.dummy_input) + # dict object to save the auto inferences objects of the submodules + self.auto_inferences = {} + # the index dict to find the corresponding torch._C.Value object + # according to the debug name + # we need the dummy_input to infer the mask automaticlly, so we save + # the indexes from tensor's debugname to the torch._C.Value object. + self.debugname_to_value = {} + # load the mask tensor to the same device with the dummy_input + # self.masks save the mask tensors pruned by the user and the infered + # masks of the others modules + self.masks = torch.load( + masks_file, map_location if map_location is not None else str(self.device)) - def infer_module_mask(self, module_name, last_module, mask=None, in_shape=None, out_shape=None): + self.constant = {} + # self.internal_result save the internal output of the submodules + self.internal_result = {} + + def _random_model_input(self, dummy_input, confidence, batch_dim): """ - Infer input shape / output shape based on the module's weight mask / input shape / output shape. + Get the new random dummy input accordint to the original dummy_input + and confidence, batch_dim. - For a module: - Infer its input and output shape from its weight mask - Infer its output shape from its input shape - Infer its input shape from its output shape + Parameters + ---------- + dummy_input: Tensor or list/dict of Tensors + The dummy_input given by the user. + confidence: int + The new batch size of the generated dummy_input. + batch_dim: int + The index of the batch dimension. + + Returns + ------ + new_dummy_input: Tensor or list/dict of Tensors + The generated dummy_input for mask inference. + device: torch.device + The device of the generated dummy_inputs + """ + input_errmsg = 'Only support the tensor, list/tuple/dict of tensors as input' + # Some model may use list of tensors as input, for example transformers + new_dummy_input, device = None, None + if isinstance(dummy_input, torch.Tensor): + input_shape = list(dummy_input.size()) + # set the batchsize to the confidence ratio + input_shape[batch_dim] = confidence + new_dummy_input = rand_like_with_shape(input_shape, dummy_input) + device = dummy_input.device + elif isinstance(dummy_input, (tuple, list)): + # else if the dummy input is list/tuple + new_dummy_input = [] + old_batchsize = dummy_input[0].size(0) + device = dummy_input[0].device + for _, t_input in enumerate(dummy_input): + assert isinstance(t_input, torch.Tensor), input_errmsg + assert t_input.size(0) == old_batchsize, 'The first dimension should be batchsize\ + and the batchsize of all inputs should be the same!' 
+ input_shape = list(t_input.size()) + input_shape[batch_dim] = confidence + # rand_func = torch.randint if t_input.dtype + new_dummy_input.append( + rand_like_with_shape(input_shape, t_input)) + elif isinstance(dummy_input, dict): + new_dummy_input = {} + tmp_key = list(dummy_input.keys())[0] + old_batchsize = dummy_input[tmp_key].size(0) + device = dummy_input[tmp_key].device + for in_name, t_input in dummy_input.items(): + assert isinstance(t_input, torch.Tensor), input_errmsg + assert old_batchsize == t_input.size(0), 'The first dimension should be batchsize\ + and the batchsize of all inputs should be the same!' + input_shape = list(t_input.size()) + input_shape[batch_dim] = confidence + new_dummy_input[in_name] = rand_like_with_shape( + input_shape, t_input) + else: + raise TypeError(input_errmsg) + return new_dummy_input, device - If its input shape is changed, continue infering its predecessors - If its output shape is changed, continue infering its successors + def _prepare_dummy_input(self, node): + """ + Prepare the dummy_input for the auto mask inference. Parameters ---------- - module_name : str - The name of the node - last_module : str - The name of last visited node - mask : tensor of mask or ModuleMasks - Mask of the weights in this node (i.e., module) - in_shape : ModuleMasks - Input shape of this node - out_shape : ModuleMasks - Output shape of this node - """ - input_cmask = output_cmask = None - if module_name in self.inferred_masks: - module_masks = self.inferred_masks[module_name] + node: NodePyGroup + + Returns + ------- + dummy_input: list + List of tensors that will be used as input for the target node. + debugnames: list of strs + Debugnames of the dummy_inputs. + """ + _logger.debug('Prepare auto mask inference for node: %s', + node.unique_name) + + # prepare the inputs and outputs mask for this node, + # if there is already a mask in self.masks, then use + # the original mask tensor, else create a new one. + inputs_name = node.inputs + # build the dummy_input, in_masks the target node + dummy_input = [] + debugnames = [] + for _input in inputs_name: + if _input not in self.internal_result: + # if the input debug name is not in self.internal_result, + # then this node isn't a output tensor of any predecessor + # nodes. This node is a attribute of the submodule, such as + # weight or bias, etc. We will skip these tensors. + # If we don't want this specific judgement here, we can merge + # the `prim::GetAttr` node of the weight/bias tensor into the key + # node, such as `conv`. + # This is caused by the `meage_module_node` function in the + # _graph_utils.py, because it doesn't merge the prim::GetAttr + # node into the key node. In current version of _graph_utils.py, + # we will only merge the nodes that have same scope name, however, + # the scope name of the correponding prim::GetAttr node of `weight` tensor + # is None. + continue + # The detach operation here is for the in-place operation. We cannot + # directly can the backward on the output tensor of an in-place operator. + dummy_input.append(self.internal_result[_input].detach()) + debugnames.append(_input) + + return dummy_input, debugnames + + def update_direct_sparsity(self, node): + """ + Update the direct sparsity for the target node. Here the direct sparsity + means that the sparsity in the output tensor that caused by the sparsity + in the input tensors/weight tensors. 
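A minimal standalone illustration of what `update_direct_sparsity` (below) detects: zeroing an input channel makes the corresponding weight slice irrelevant, so the sparsity propagates directly through the module (demo tensors, not part of the patch):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 4, 3, bias=False)
x = torch.rand(8, 3, 16, 16)
x[:, 1] = 0                        # input channel 1 is masked
out = conv(x)
# the output no longer depends on conv.weight[:, 1]; masking that slice
# changes nothing, which is exactly the sparsity the forward pass infers
conv.weight.data[:, 1] = 0
assert torch.allclose(out, conv(x))
```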
+ """ + # this name is consistent with the name returned by named_modules() + module_name = node.name + _logger.info('Update mask for %s', module_name) + unique_name = node.unique_name + dummy_input, input_debugname = self._prepare_dummy_input(node) + # get the input mask from self.masks + # Note: the input mask of the successor nodes are + # already created by the predecessor node + in_masks = [self.masks[debugname] for debugname in input_debugname] + in_constants = [self.constant[debugname] + for debugname in input_debugname] + if node.type == 'func': + # we cannot get the runable function directly from the jit traced + # graph, so we translate it back to python function, Note: the function + # is appliable to both cpu/gpu devices, the output tensors will be on the + # same device of the input tensors + func = jit_to_python_function(node, self) + if func is None: + # no need to infer the sparsity for this node + self.auto_inferences[unique_name] = None + return + # function doesn't have weights + _auto_infer = AutoMaskInference( + func, dummy_input, in_masks, in_constants=in_constants, batch_dim=self.batch_dim) else: - _, m = get_module_by_name(self.bound_model, module_name) - module_masks = ModuleMasks(module_name, m) - self.inferred_masks[module_name] = module_masks - - m_type = self.torch_graph.name_to_node[module_name].op_type - _logger.debug("infer mask of module %s with op_type %s", module_name, m_type) - if mask is not None: - _logger.debug("mask is not None") - if not m_type in infer_from_mask: - raise RuntimeError( - "Has not supported infering input/output shape from mask for module/function: `{}`, {}" - .format(m_type, module_name)) - if m_type in ['Linear']: - input_cmask, output_cmask = infer_from_mask[m_type]( - module_masks, mask, self.torch_graph.name_to_node[module_name].auxiliary - ) - else: - input_cmask, output_cmask = infer_from_mask[m_type](module_masks, mask) - if in_shape is not None: - _logger.debug("in_shape is not None") - if not m_type in infer_from_inshape: - raise RuntimeError( - "Has not supported infering output shape from input shape for module/function: `{}`, {}" - .format(m_type, module_name)) - if m_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']: - output_cmask = infer_from_inshape[m_type](module_masks, - in_shape, - self.torch_graph.name_to_node[module_name].auxiliary) - elif m_type in ['aten::cat']: - # To calculate the mask for concat operation, the output shape - # , cat dimension, and the order of the input parameters. 
- output_cmask = infer_from_inshape[m_type](module_masks, - in_shape, - self.torch_graph.name_to_node[module_name].auxiliary, - last_module) - else: - output_cmask = infer_from_inshape[m_type](module_masks, in_shape) - if out_shape is not None: - _logger.debug("out_shape is not None") - if not m_type in infer_from_outshape: - raise RuntimeError( - "Has not supported infering input shape from output shape for module/function: `{}`, {}" - .format(m_type, module_name)) - if m_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']: - input_cmask = infer_from_outshape[m_type](module_masks, out_shape, self.torch_graph.name_to_node[module_name].auxiliary) - else: - input_cmask = infer_from_outshape[m_type](module_masks, out_shape) + weight_mask = None + if module_name in self.masks: + weight_mask = self.masks[module_name] + _, module = get_module_by_name(self.bound_model, module_name) + _auto_infer = AutoMaskInference( + module, dummy_input, in_masks, weight_mask, in_constants=in_constants, + state_dict=copy.deepcopy(module.state_dict()), batch_dim=self.batch_dim) + self.auto_inferences[unique_name] = _auto_infer + _auto_infer.name = node.unique_name + + _auto_infer.update_direct_sparsity() + # also save the input debug names into the auto_infer + _auto_infer.input_debugname = input_debugname + # update the mask tensor and the internal output of the submodules + # after manually unpack the tuple/list of tensors, the number of the outputs + # of each node should always be one(Except for the TupleUnpack node at the end + # of the whole model) + assert len( + node.outputs) == 1, 'The number of the output should be one after the Tuple unpacked manually' + + out_debugname = node.outputs[0] + # update the output mask into self.masks + self.masks[out_debugname] = _auto_infer.output_mask + self.constant[out_debugname] = _auto_infer.out_constant + # update the output result into self.internal_result, so that + # the successor nodes can take these output tensors as inputs. + self.internal_result[out_debugname] = _auto_infer.output + # update the parameter mask of the node - if input_cmask: - predecessors = self.torch_graph.find_predecessors(module_name) - for _module_name in predecessors: - self.infer_module_mask(_module_name, module_name, out_shape=input_cmask) - if output_cmask: - successors = self.torch_graph.find_successors(module_name) - for _module_name in successors: - self.infer_module_mask(_module_name, module_name, in_shape=output_cmask) + self.masks[module_name] = _auto_infer.weight_mask + + def update_indirect_sparsity(self, node): + """ + This function will update the indirect sparsity. To explain what's + indirect sparsity, for example, there is two tensors TA and TB, and + we perform the calculation: TC = TA x TB in which TC is also a tensor. + Once some values in TA are masked to zeros, then the corresponding + positions in TB are also potential sparsities, because these have no + effect of the final output(the gradient of these positions in TB equal + to 0 all the time). This function it to fine the potential sparsity caused + by other sparsity(we call it indirect sparsity here). Basically we can find + these potential sparsity through gradient. 
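The gradient trick that `update_indirect_sparsity` (below) relies on, reduced to a toy case (demo only, not part of the patch): mask part of the output, run backward, and input positions whose gradient is always zero are candidates for indirect sparsity.

```python
import torch

ta = torch.rand(8, 4, requires_grad=True)
tb = torch.eye(4)                 # with tb = I, tc[:, j] depends only on ta[:, j]
tc = ta.matmul(tb)
out_mask = torch.ones_like(tc)
out_mask[:, 0] = 0                # pretend column 0 of TC is pruned
(tc * out_mask).sum().backward()
print(ta.grad[:, 0].abs().sum())  # tensor(0.): ta[:, 0] never reaches an unmasked output
```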
+
+    def update_indirect_sparsity(self, node):
+        """
+        This function will update the indirect sparsity. To explain what
+        indirect sparsity is: suppose there are two tensors TA and TB, and
+        we perform the calculation TC = TA x TB, in which TC is also a tensor.
+        Once some values in TA are masked to zeros, then the corresponding
+        positions in TB are also potential sparsity, because they have no
+        effect on the final output (the gradient of these positions in TB equals
+        0 all the time). This function is to find the potential sparsity caused
+        by other sparsity (we call it indirect sparsity here). Basically we can find
+        this potential sparsity through gradients.
+
+        Parameters
+        ----------
+        node: NodePy
+            The target node to update the indirect sparsity
+        """
+        module_name = node.name
+        _logger.info('Update indirect sparsity for %s', module_name)
+        unique_name = node.unique_name
+        if unique_name in self.auto_inferences and self.auto_inferences[unique_name] is not None:
+            # if the auto inference object is already in self.auto_inferences, then
+            # directly update the previous one
+            _logger.info(
+                'Update the indirect sparsity for the %s', unique_name)
+            auto_infer = self.auto_inferences[unique_name]
+            auto_infer.update_indirect_sparsity()
+            # pass the gradient to the predecessor nodes
+            for in_id, tin in enumerate(auto_infer.dummy_input):
+                debug_name = auto_infer.input_debugname[in_id]
+                last_output = self.internal_result[debug_name]
+                # TODO what if last output is a tuple/list of tensors
+                if last_output.grad is not None and tin.grad is not None:
+                    last_output.grad.data += tin.grad.data
+                else:
+                    last_output.grad = tin.grad
+        else:
+            _logger.warning('Note: %s does not have a corresponding mask inference object', node.name)
+
+    def _vnode_to_value(self, c_node):
+        """
+        Translate the torch._C.Value node into values/tensors.
+        """
+        errmsg = "Only support the torch._C.Value type"
+        assert isinstance(c_node, torch._C.Value), errmsg
+        if isinstance(c_node.type(), torch._C.TensorType):
+            shape = tuple(c_node.type().sizes())
+            dtype = c_node.type().scalarType()
+            # TODO should use a more general way to get the input
+            if dtype.startswith('Float') or dtype.startswith('Double'):
+                return torch.rand(shape).to(self.device)
+            else:
+                # This small range is due to `ReLU6`; we will add
+                # manual specific mask inference rules for several
+                # ops in the future, so that we can remove the constraint.
+                return torch.randint(0, 10, shape, device=self.device)
+        else:
+            value = c_node.toIValue()
+            # TODO support more kinds of value node
+            errmsg = "Doesn't support converting %s to values" % str(c_node.type())
+            # currently only support the tensors and constant values
+            assert value is not None, errmsg
+            return value
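The forward phase of `infer_modules_masks` below is Kahn's topological traversal over the traced graph; distilled to a generic helper, it looks like this (sketch, not part of the patch):

```python
import queue

def topo_visit(nodes, successors_of, visit):
    # count in-degrees, then repeatedly visit nodes whose
    # predecessors have all been processed
    in_degree = {n: 0 for n in nodes}
    for n in nodes:
        for s in successors_of(n):
            in_degree[s] += 1
    q = queue.Queue()
    for n in nodes:
        if in_degree[n] == 0:
            q.put(n)
    while not q.empty():
        cur = q.get()
        visit(cur)
        for s in successors_of(cur):
            in_degree[s] -= 1
            if in_degree[s] == 0:
                q.put(s)
```

The backward phase is the same traversal run on the reversed graph, seeded with the nodes of out-degree zero.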
+ """ + # unpack the tensor tuple/list before the mask inference + self.torch_graph.unpack_manually() + # find the input/ouput tensor of the whole graph + graph_input = [] + graph_output = [] + for name, nodeio in self.torch_graph.nodes_py.nodes_io.items(): + if nodeio.input_or_output == 'input': + graph_input.append((name, nodeio)) + # also put the graph input tensor into the internal_result + # TODO if we can find the corresponding relation between the value node + # and the dummy_inputs, we can use the inputs value in the dummy_input + value = self._vnode_to_value(self.debugname_to_value[name]) + self.internal_result[name] = value + # create the mask tensor for the input value + if isinstance(self.internal_result[name], torch.Tensor): + self.masks[name] = torch.ones_like(value) + self.constant[name] = torch.zeros_like(value) + elif nodeio.input_or_output == 'output': + graph_output.append((name, nodeio)) + # count the degree for the node in the graph + in_degree = {} + out_degree = {} + visit_queue = queue.Queue() + for node in self.torch_graph.nodes_py.nodes_op: + successors = self.torch_graph.find_successors(node.unique_name) + out_degree[node.unique_name] = len(successors) + predecessors = self.torch_graph.find_predecessors(node.unique_name) + in_degree[node.unique_name] = len(predecessors) + if in_degree[node.unique_name] == 0: + visit_queue.put(node) + # Forward mask inference + while not visit_queue.empty(): + curnode = visit_queue.get() + # forward mask inference for curnode + self.update_direct_sparsity(curnode) + successors = self.torch_graph.find_successors(curnode.unique_name) + for successor in successors: + in_degree[successor] -= 1 + if in_degree[successor] == 0: + visit_queue.put(self.torch_graph.name_to_node[successor]) + # backward mask inference + for unique_name in out_degree: + if out_degree[unique_name] == 0: + visit_queue.put(self.torch_graph.name_to_node[unique_name]) + while not visit_queue.empty(): + curnode = visit_queue.get() + self.update_indirect_sparsity(curnode) + predecessors = self.torch_graph.find_predecessors( + curnode.unique_name) + for predecessor in predecessors: + out_degree[predecessor] -= 1 + if out_degree[predecessor] == 0: + visit_queue.put(self.torch_graph.name_to_node[predecessor]) def replace_compressed_modules(self): """ @@ -148,40 +377,138 @@ def replace_compressed_modules(self): NOTE: ```func``` type cannot be replaced as it is not a module, thus, one limitation is that ```func``` should be not required to be replaced. 
""" - for module_name in self.inferred_masks: - g_node = self.torch_graph.name_to_node[module_name] - _logger.debug("replace %s, in %s type, with op_type %s", - module_name, g_node.type, g_node.op_type) - if g_node.type == 'module': - super_module, leaf_module = get_module_by_name(self.bound_model, g_node.name) - m_type = g_node.op_type - if not m_type in replace_module: - raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type)) - _logger.info("replace module (name: %s, op_type: %s)", g_node.name, m_type) - compressed_module = replace_module[m_type](leaf_module, self.inferred_masks[module_name]) - setattr(super_module, g_node.name.split('.')[-1], compressed_module) - elif g_node.type == 'func': - _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type", - module_name, g_node.op_type) - else: - raise RuntimeError("Unsupported node type: {}".format(g_node.type)) + with torch.no_grad(): + for unique_name in self.auto_inferences: + self.replace_submodule(unique_name) + + def replace_submodule(self, unique_name, reindex_dim=None, reindex=None): + """ + Replace the submodule according to the inferred sparsity. + unique_name: str + The unique_name of the submodule to replace. + reindex_dim: int + The dimension of the re-index operation. + reindex: Reindex + The index tensor. Normally this variable is None. If we want to reindex the + output of this submodule, we can pass the index by this parameter. + """ + class ReindexModule(nn.Module): + """ + ReindexModule is used to resolve the mask conflict when replace the submodule. + Basically, we can use two ways to resolve the mask conflict: (1) unmask some + values(will introduce more computation overhead) (2) reindex and padd the output + tensor of the target op(introduce more memory access overhad). Currently this + method is shutdown, in the future, we will merge these two methods into a graph + pass which is used to resolve the mask conflict. 
+ """ + def __init__(self, ori_module, reindex_dim, reindex): + super(ReindexModule, self).__init__() + self.ori_module = ori_module + self.reindex_dim = reindex_dim + self.reindex = reindex + tmp_index = [slice(None, None) for i in range(reindex_dim+1)] + # the index for the tensor + tmp_index[reindex_dim] = reindex + self.t_index = tuple(tmp_index) + + def forward(self, x): + tmpout = self.ori_module(x) + shape = list(tmpout.size()) + shape[self.reindex_dim] = self.reindex.size(0) + out = torch.zeros(tuple(shape), device=tmpout.device, + requires_grad=tmpout.requires_grad) + out[self.t_index] = tmpout + return out + + assert unique_name in self.auto_inferences + g_node = self.torch_graph.name_to_node[unique_name] + _logger.debug("replace %s, in %s type, with op_type %s", + unique_name, g_node.type, g_node.op_type) + auto_infer = self.auto_inferences[unique_name] + if g_node.type == 'module': + if g_node.unique_name in self.torch_graph.reused_module: + if reindex_dim is not None: + _logger.warning( + 'Cannot replace a reused module with padding operator!!') + return None + super_module, leaf_module = get_module_by_name( + self.bound_model, g_node.name) + m_type = g_node.op_type + if not m_type in replace_module: + raise RuntimeError( + "Has not supported replacing the module: `{}`".format(m_type)) + _logger.info("replace module (name: %s, op_type: %s)", + g_node.name, m_type) + compressed_module = replace_module[m_type]( + leaf_module, auto_infer.get_masks()) + new_submodule = compressed_module + if reindex_dim is None: + setattr(super_module, g_node.name.split( + '.')[-1], compressed_module) + elif reindex_dim is not None and reindex is not None: + # reindex the output of this submodule and replace the orginal module + new_submodule = ReindexModule( + compressed_module, reindex_dim, reindex) + setattr(super_module, g_node.name.split( + '.')[-1], new_submodule) + return new_submodule + elif g_node.type == 'func': + _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type", + unique_name, g_node.op_type) + return None + else: + raise RuntimeError("Unsupported node type: {}".format(g_node.type)) + + def initialize_speedup(self): + """ + Do some initial work for speedup. + """ + # initialize the self.debugname_to_value + # build a mapping table from the debug name of the tensor + # to its value node in the graph + traced_graph = self.torch_graph.trace.graph + for node in traced_graph.nodes(): + for _input in node.inputs(): + debug_name = _input.debugName() + if debug_name not in self.debugname_to_value: + self.debugname_to_value[debug_name] = _input + for _output in node.outputs(): + debug_name = _output.debugName() + if debug_name not in self.debugname_to_value: + self.debugname_to_value[debug_name] = _output + # put the model itself into internel_result to perform the + # value inference for the 'prim::GetAttr', the first ClassType + # of the whole graph is the model class + + for graph_input in traced_graph.inputs(): + if graph_input.type().kind() == 'ClassType': + self.internal_result[graph_input.debugName() + ] = self.bound_model + break def speedup_model(self): """ - There are basically two steps: - first, do mask/shape inference, - second, replace modules + There are basically two steps: first, do mask/shape inference, + second, replace modules. 
""" - training = self.bound_model.training - _logger.info("start to speed up the model") - _logger.info("fix the mask conflict of the interdependent layers") - _, conv_prune_dim = fix_mask_conflict(self.masks, self.bound_model, self.dummy_input) - set_conv_prune_dim(conv_prune_dim) + _logger.info("start to speed up the model") + self.initialize_speedup() + training = self.bound_model.training + # set to the evaluation mode + self.bound_model.train(False) + # TODO suppose to fix the conflict after the sparsity propagation + # which is more elegent + fix_mask_conflict(self.masks, self.bound_model, self.dummy_input) _logger.info("infer module masks...") self.infer_modules_masks() + _logger.info('resolve the mask conflict') + + # load the original stat dict before replace the model + self.bound_model.load_state_dict(self.ori_state_dict) _logger.info("replace compressed modules...") + # the mask conflict should be already resolved self.replace_compressed_modules() self.bound_model.train(training) _logger.info("speedup done") diff --git a/nni/compression/pytorch/speedup/infer_mask.py b/nni/compression/pytorch/speedup/infer_mask.py new file mode 100644 index 0000000000..cdede1b675 --- /dev/null +++ b/nni/compression/pytorch/speedup/infer_mask.py @@ -0,0 +1,378 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +import torch +import torch.nn as nn +from ..utils import randomize_tensor, torch_float_dtype, torch_integer_dtype +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + +STD_DELTA = 1e-6 + + +class AutoMaskInference: + def __init__(self, module, dummy_input, in_masks=None, weight_mask=None, \ + output_mask=None, name=None, in_constants=None, state_dict=None, batch_dim=0): + """ + This class will infer the mask of the target module automatically. + This update_direct_sparsity will infer the output mask according + to the input masks, in constrast, update_indirect_sparsity will + infer the input masks according to given output masks. The newly + found sparsity will be incrementally updated to the original in_masks + and output_mask. + + Parameters + ---------- + module: torch.nn.Module/function + The target module to infer the mask. Need to be callable. + dummy_input: torch.Tensor/list of Tensor + The dummy_input of the target module. + in_masks: list of torch.Tensor + The input masks of the target module, if in_masks is not None, then + update_direct_sparsity and update_indirect_sparsity will incrementally + update the given in_masks, else, AutoMaskInference will create a new + in_masks for the target module. + output_mask: torch.Tensor + The output mask of the target module. Similar to in_masks, if output_mask + is not None, then update_direct_sparsity and update_indirect_sparsity will + incrementally update the given output_mask, else AutoMaskInference will create + one output_mask for the target module. + weight_mask: dict of the weight masks + The weight masks of the target module, the key is the corresponding name of + the mask. For example: {'weight':torch.ones(1000, 1000), bias:torch.ones(1000)} + name: str + Name of the target module. + in_constants: list of torch.Tensor + The correponding constant values of the in_masks. + state_dict: dict of torch.Tensor + The original values of the weights. + batch_dim: int + The index of the batch dimension of the input tensors. 
+
+        """
+        errmsg = '%s is not callable, you should pass an nn.Module/function' % str(
+            module)
+        assert callable(module), errmsg
+        self.module = module
+
+        # Initialize the dummy_input
+        if isinstance(dummy_input, list):
+            # if there are multiple input variables
+            self.dummy_input = dummy_input
+        else:
+            # if there is only one input variable
+            self.dummy_input = [dummy_input]
+
+        # Initialize the masks for input tensors
+        self.in_masks = in_masks if in_masks is not None else [
+            None] * len(self.dummy_input)
+        self.in_constants = in_constants if in_constants is not None else [
+            torch.zeros_like(x) for x in self.dummy_input]
+        for in_id, _ in enumerate(self.in_masks):
+            if self.in_masks[in_id] is None and \
+                    isinstance(self.dummy_input[in_id], torch.Tensor):
+                # if the input mask is None, create an all-ones mask for the
+                # corresponding input tensor
+                self.in_masks[in_id] = torch.ones_like(self.dummy_input[in_id])
+                # ones_like puts the created mask on the same device as the dummy_input
+
+        # Initialize the mask for output tensors
+        self.output = self.module(*self.dummy_input)
+        if output_mask is not None:
+            # assume the given output mask is right
+            self.output_mask = output_mask
+        else:
+            if isinstance(self.output, torch.Tensor):
+                self.output_mask = torch.ones_like(self.output)
+            elif isinstance(self.output, (list, tuple)):
+                self.output_mask = []
+                for o_tensor in self.output:
+                    if isinstance(o_tensor, torch.Tensor):
+                        self.output_mask.append(torch.ones_like(o_tensor))
+                    else:
+                        # if one of the outputs is not a tensor, set the
+                        # corresponding mask to None
+                        self.output_mask.append(None)
+            else:
+                self.output_mask = None
+
+        # Initialize the masks for the parameters
+        self.weights = {}
+        self.weight_mask = {}
+        if weight_mask:
+            self.weight_mask.update(weight_mask)
+        if isinstance(self.module, nn.Module):
+            # functions should not have parameters;
+            # get all the parameter tensors of the target module
+            for para_name, para in module.named_parameters():
+                self.weights[para_name] = para
+                if para_name not in self.weight_mask:
+                    self.weight_mask[para_name] = torch.ones_like(para.data)
+        self.name = name
+        self.state_dict = state_dict
+        # TODO support other batch dimensions in the future
+        self.batch_dim = batch_dim
+
+    def random_init(self, start=0.1, end=8.0):
+        """
+        Randomly initialize the weights of the module. The values of
+        the tensors will not affect the mask auto inference.
+        """
+        # currently we set the random range to 0.1-8.0 because of ReLU6:
+        # if we use a range far larger than 6, it may infer a wrong mask
+        # when the confidence is low. In the future, we will add mask
+        # inference rules for ReLU6 to break this range constraint.
+        with torch.no_grad():
+            for tensor in self.dummy_input:
+                if isinstance(tensor, torch.Tensor) and len(tensor.size()) > 0:
+                    # if the tensor is a scalar, skip it
+                    randomize_tensor(tensor, start, end)
+            for para in self.weights:
+                randomize_tensor(self.weights[para].data, start, end)
+
+    def zero_grad(self):
+        """
+        Set the gradients of the weights and input tensors to zeros.
+        """
+        with torch.no_grad():
+            # set the weights' gradients to zero
+            if isinstance(self.module, nn.Module):
+                self.module.zero_grad()
+            # also zero the gradients of the input tensors
+            for tensor in self.dummy_input:
+                if isinstance(tensor, torch.Tensor):
+                    if tensor.grad is not None:
+                        tensor.grad.data.zero_()
+
+    def requires_grad_(self, flag=True):
+        """
+        Set the requires_grad of the input tensors and parameters to flag.
+ """ + for t_in in self.dummy_input: + if isinstance(t_in, torch.Tensor) and t_in.dtype in torch_float_dtype: + # only float type can require the gradient + # enable the auto gradient + t_in.requires_grad_(flag) + for para_name in self.weights: + if self.weights[para_name].dtype in torch_float_dtype: + self.weights[para_name].requires_grad_(flag) + + def apply_mask(self): + self.__apply_input_mask() + self.__apply_weight_mask() + + def __apply_input_mask(self): + """ + Apply the mask of the input tensor. + """ + with torch.no_grad(): + # apply the input mask + for tid, in_tensor in enumerate(self.dummy_input): + if isinstance(in_tensor, torch.Tensor) and self.in_masks[tid] is not None: + in_tensor.data = in_tensor.data * \ + self.in_masks[tid] + \ + (1-self.in_masks[tid]) * self.in_constants[tid] + + + def __apply_weight_mask(self): + """ + Apply the weight mask of this module. + """ + with torch.no_grad(): + # apply the weight mask + for para in self.weights: + if para in self.weight_mask: + self.weights[para].data *= self.weight_mask[para].data + + def isconstants(self, tout): + """ + Find the constants in the tensor tout. This function return a mask tensor that + indicates if a value in tout is a constant, and return one more tensor to indicate + that the values of the constant. + + Paramters + --------- + tout: torch.Tensor + The target output tensor to find the constants + Returns + ------- + mask: torch.Tensor + The mask tensor(same shape with tout) that indicates that whether + the correponding value is a constant. + constant: torch.Tensor + The mask tensot(same shape with tout) that indicates the values of + the constants in the tout. + """ + assert isinstance(tout, torch.Tensor) + out_mask = torch.ones_like(tout) + constant = torch.zeros_like(tout) + # judge if tout is a scalar(tensor that only have one value) + if len(tout.size()) == 0: + # tout is a scalar tensor, for the scalar tensor, we take + # this scalar as a constant, usually, the scalar tensor is returned + # by the size() function + constant = tout + return out_mask, constant + if tout.dtype in torch_integer_dtype: + # Pytorch cannot use torch.mean and torch.std to process + # intergers :( , so if dtype of the input tensor is integer, we need + # check if is the constant by ourselves + # Note: the first dimension should be the batch dimension + same = tout[:] == tout[0] + reduced = torch.sum(same, dim=0) + is_constant = reduced == tout.size(0) + out_mask[:, is_constant] = 0 + constant[:, is_constant] = tout[0][is_constant] + + else: + # calculate the std of the output among batch dimension + std = torch.std(tout, dim=0) + # calculate the mean value of the output among the batch dimension + mean = torch.mean(tout, dim=0) + mask_pos = std < STD_DELTA + out_mask[:, mask_pos] = 0 + constant[:, mask_pos] = mean[mask_pos] + return out_mask, constant + + + def update_indirect_sparsity(self): + """ + This function will update the indirect sparsity. To explain what's + indirect sparsity, for example, there is two tensors TA and TB, and + we perform the calculation: TC = TA x TB in which TC is also a tensor. + Once some values in TA are masked to zeros, then the corresponding + positions in TB are also potential sparsities, because these have no + effect of the final output(the gradient of these positions in TB equal + to 0 all the time). This function it to fine the potential sparsity caused + by other sparsity(we call it indirect sparsity here). Basically we can find + these potential sparsity through gradient. 
+ """ + # Each node only update the output mask when we backwards + # update the output mask, this is because that some op may + # have the broadcast operation, for example, OP A's output + # tensor may be taken by two OPs(B, C) as inputs. So we cannot + # directly update the input mask at the OP B or C. We can only + # update the mask of C's output tensor only when B and C are + # already updated(gradient are already calculated and added to + # C's output tensor). + # Besides, updating the mask of C's output tensor equals to updating + # the input mask of OP B and C. + if isinstance(self.output, torch.Tensor) and self.output.grad is not None: + # if output have gradient which means this node has successor + # nodes and the successor nodes have already update their indirect + # sparsity + # we can mask the values whose gradient is always zeros + gradient_sum = torch.sum(torch.abs(self.output.grad.data), dim=0) + _grad_zero = gradient_sum == 0 + for batchid in range(self.output.size(0)): + # set the same mask value for the whole batche + self.output_mask[batchid][_grad_zero] = 0 + elif isinstance(self.output, tuple) or isinstance(self.output, list): + assert isinstance(self.output_mask, (tuple, list)) + for oid, tout in enumerate(self.output): + errmsg = 'The output only support tensor/list of tensors' + assert isinstance(tout, torch.Tensor), errmsg + gradient_sum = torch.sum( + torch.abs(self.output.grad.data), dim=0) + _grad_zero = gradient_sum == 0 + for batchid in range(self.output.size(0)): + # set the same mask value for the whole batch + self.output_mask[oid][batchid][_grad_zero] = 0 + + self.requires_grad_(True) + # Forward inference with auto gradient enabled + # Note: tensors that need gradient cannot be used in the in-place operator + self.random_init() + self.apply_mask() + # Some operator may have the in_place operations, so we need to clone the input + # before passing to the self.module + tmp_dummy_input = [x.clone() if isinstance( + x, torch.Tensor) else x for x in self.dummy_input] + output = self.module(*tmp_dummy_input) + + if output.grad_fn is None: + # the output does not have the gradient function + return + # Note: output maybe tensor or list/tuple of tensors + if isinstance(output, torch.Tensor): + output.backward(self.output_mask) + elif isinstance(output, list) or isinstance(output, tuple): + for tid, t_out in enumerate(output): + t_out.backward(self.output_mask[tid]) + + # update the sparsity of the paramters + for para_name in self.weights: + grad_zero = self.weights[para_name].grad.data == 0 + self.weight_mask[para_name][grad_zero] = 0 + + def update_direct_sparsity(self): + # we don't need the gradient in the forward inference + out_mask = None + constant = None + with torch.no_grad(): + # Note: we need randomly init the input one more time here! + # Because some operation have the in-place operation, such as relu_, + # the in-place operation may modify or write 0s into the dummy_input + self.random_init() + # apply the mask for the input tensor and the weight tensor + self.apply_mask() + # Note: due to the in-place operator, such as relu_, + # ori_out may be the same tensor with dummy_input, + # so we use clone and detach to create a new tensor with + # the same values. 
+            out = self.module(*self.dummy_input)
+            if isinstance(out, torch.Tensor):
+                out_mask, constant = self.isconstants(out.clone().detach())
+            elif isinstance(out, (tuple, list)):
+                out_mask = []
+                constant = []
+                for tout in out:
+                    _mask, _constant = self.isconstants(tout.clone().detach())
+                    out_mask.append(_mask)
+                    constant.append(_constant)
+            else:
+                _logger.warning(
+                    'Only ops whose output is a tensor/tuple of tensors/list of tensors are supported')
+
+            # We also need to randomize the parameters of the module, because
+            # if the weights of the model have an unmasked 0, our output
+            # sparsity inference may be wrong. However, after randomizing the
+            # weights/parameters, the constants in the output tensors may
+            # differ from the constants calculated from the original
+            # state_dict. So, to get the right constants and eliminate the
+            # bias between the model before and after sparsity inference, we
+            # need to reload the state_dict and recalculate the constants.
+            # Currently we also compute the constant values at the same time
+            # as inferring the masks; in the future, we will separate the
+            # constant inference into its own graph pass.
+            if len(self.weights) > 0 and self.state_dict is not None:
+                self.module.load_state_dict(self.state_dict)
+                # apply weight mask
+                self.__apply_weight_mask()
+                out = self.module(*self.dummy_input).clone().detach()
+                if isinstance(out, torch.Tensor):
+                    constant = torch.zeros_like(out)
+                    constant_pos = out_mask == 0
+                    constant[constant_pos] = out[constant_pos]
+                elif isinstance(out, (list, tuple)):
+                    constant = []
+                    for i, tout in enumerate(out):
+                        _tmp = torch.zeros_like(tout)
+                        sparsity_pos = out_mask[i] == 0
+                        _tmp[sparsity_pos] = tout[sparsity_pos]
+                        constant.append(_tmp)
+
+        if isinstance(out_mask, torch.Tensor):
+            assert isinstance(self.output_mask, torch.Tensor)
+            self.output_mask *= out_mask
+        elif isinstance(out_mask, list):
+            for i, _ in enumerate(out_mask):
+                self.output_mask[i] *= out_mask[i]
+        else:
+            _logger.warning('There is no output sparsity')
+        # also save the out_constant
+        self.out_constant = constant
+
+    def get_masks(self):
+        return (self.in_masks, self.output_mask, self.weight_mask)
+
diff --git a/nni/compression/pytorch/speedup/infer_shape.py b/nni/compression/pytorch/speedup/infer_shape.py
deleted file mode 100644
index 693ef32b48..0000000000
--- a/nni/compression/pytorch/speedup/infer_shape.py
+++ /dev/null
@@ -1,1146 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT license.
-"""
-For each operation or module, there are two functions.
-One is given output shape, infer its input shape and initialization parameters (e.g., weight's shape) -The other is given input shape, infer its output shape and initialization parameters (e.g., weight's shape) -""" - -import logging -import torch - -_logger = logging.getLogger(__name__) - -conv_prune_dim = -1 - - -def set_conv_prune_dim(dim): - """ - Parameters: - dim: int - 0: filter pruning - 1: channel pruning - """ - global conv_prune_dim - conv_prune_dim = dim - - -class CoarseMask: - """ - Coarse grained mask for a given tensor, here tensor could be weights, - input tensor, or output tensor - """ - - def __init__(self, num_dim): - """ - Parameters - ---------- - num_dim : int - The number of dimensions of the tensor that will be masked - """ - self.mask_index = [None for _ in range(num_dim)] - - def add_index_mask(self, dim, index): - """ - Add mask for the specified dimension - - Parameters - ---------- - dim : int - The dimension to add mask - index : tensor - The mask for this dimension, its a 1 dimension tensor which specifies - the index of the elements that are not pruned - """ - self.mask_index[dim] = index - - @staticmethod - def merge_index(index_a, index_b): - """ - Parameters - ---------- - index_a : tensor - One index (1-dimension) tensor - index_b : tensor - The other index (1-dimension) tensor - - Returns - ------- - tensor - The merged index (1-dimension) tensor - Note that: the output tensor will be moved - to the same device as index_a. - """ - device = index_a.device - s = set() - for num in index_a.tolist(): - # we need to transfer the tensor to list here - # first, directly traversing the tensor by for - # loop will return the list of tensor(x) object, - # even the value are the same, but they are different - # tensor objects, so the set will contains multiple - # tensor objects that has the same value. For example - # for num in torch.ones(2): - # s.add(num) - # s will be {tensor(1), tensor(1)} - s.add(num) - for num in index_b.tolist(): - s.add(num) - # move the output tensor to the same device with index_a - return torch.tensor(sorted(s)).to(device) # pylint: disable=not-callable - - def merge(self, cmask): - """ - Merge another CoarseMask - - Parameters - ---------- - cmask : CoarseMask - Another CoarseMask to merge - - Returns - ------- - list - The member variable ```mask_index``` - """ - assert isinstance(cmask, CoarseMask) - assert len(self.mask_index) == len(cmask.mask_index), \ - "Only masks with the same number of dimensions can be merged" - for i, index in enumerate(self.mask_index): - if index is None: - self.mask_index[i] = cmask.mask_index[i] - elif cmask.mask_index[i] is not None: - self.mask_index[i] = CoarseMask.merge_index(self.mask_index[i], - cmask.mask_index[i]) - return self.mask_index - - def __repr__(self): - return 'mask_index: {}'.format(self.mask_index) - - def eq_on_dim(self, other, dim): - assert isinstance(other, CoarseMask) - if self.mask_index[dim] is None and other.mask_index[dim] is None: - return True - elif isinstance(self.mask_index[dim], torch.Tensor) \ - and isinstance(other.mask_index[dim], torch.Tensor): - return torch.equal(self.mask_index[dim], other.mask_index[dim]) - else: - return False - - def __eq__(self, other): - assert isinstance(other, CoarseMask) - if len(self.mask_index) != len(other.mask_index): - return False - for i in range(len(self.mask_index)): - if not self.eq_on_dim(other, i): - return False - return True - - def __lt__(self, other): - """ - Judge if the mask is a subset of another CoarseMask. 
- """ - assert isinstance(other, CoarseMask) - for dim, _ in enumerate(self.mask_index): - # if self has more dimensions - if dim >= len(other.mask_index): - return False - if self.mask_index[dim] is None: - # if no mask on this dimension, then we have less - # masks then the other CoraseMask. - continue - elif other.mask_index[dim] is None: - return False - else: - s1 = set(self.mask_index[dim].tolist()) - s2 = set(other.mask_index[dim].tolist()) - if not s1 < s2: - return False - return True - - def __le__(self, other): - """ - Return if self's mask is less or equal to other's mask. - """ - assert isinstance(other, CoarseMask) - if self.__lt__(other) or self.__eq__(other): - return True - return False - - def __ne__(self, other): - return not self.__eq__(other) - - -class ModuleMasks: - """ - The masks of a module, including the masks for weights, inputs, output - """ - - def __init__(self, module_name, module=None): - """ - Parameters - ---------- - module_name : str - The name of the module or function - """ - self.module_name = module_name - self.module = module - self.param_masks = dict() - self.input_mask = None - self.output_mask = None - - def set_param_masks(self, name, mask): - """ - Parameters - ---------- - name : str - The name of the weight - mask : CoarseMask - The mask for this weight - """ - self.param_masks[name] = mask - - def set_input_mask(self, mask): - """ - Parameters - ---------- - mask : CoarseMask - The mask for input - """ - self.input_mask = mask - - def set_output_mask(self, mask): - """ - Parameters - ---------- - mask : CoarseMask - The mask for output - """ - self.output_mask = mask - - def __repr__(self): - return 'module_name: {}, input_mask: {}, output_mask: {}, param_masks: {}'.format( - self.module_name, self.input_mask, self.output_mask, self.param_masks - ) - - -""" -Infer input and output shape of a module/function from its weight mask -""" -infer_from_mask = { - 'BatchNorm2d': lambda module_masks, mask: batchnorm2d_mask(module_masks, mask), - 'Conv2d': lambda module_masks, mask: conv2d_mask(module_masks, mask), - 'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_mask(module_masks, mask), - 'Linear': lambda module_masks, mask, shape: linear_mask(module_masks, mask, shape) -} - -""" -Infer output and weight shape of a module/function from its input shape -""" -infer_from_inshape = { - 'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask), - 'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask), - 'PReLU': lambda module_masks, mask: prelu_inshape(module_masks, mask), - 'Sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask), - 'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask), - 'aten::tanh': lambda module_masks, mask: relu_inshape(module_masks, mask), - 'aten::tanh_': lambda module_masks, mask: relu_inshape(module_masks, mask), - 'aten::hardtanh': lambda module_masks, mask: relu_inshape(module_masks, mask), - 'aten::hardtanh_': lambda module_masks, mask: relu_inshape(module_masks, mask), - 'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask), - 'aten::sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask), - 'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask), - 'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_inshape(module_masks, mask), - 'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), - 'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, 
mask), - 'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), - 'aten::adaptive_avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), - 'AvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), - 'AdaptiveAvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), - 'aten::size': lambda module_masks, mask: size_inshape(module_masks, mask), - 'aten::view': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape), - 'aten::reshape': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape), - # support only start_dim=1 - 'aten::flatten': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape), - 'Linear': lambda module_masks, mask: linear_inshape(module_masks, mask), - 'BatchNorm2d': lambda module_masks, mask: batchnorm2d_inshape(module_masks, mask), - 'aten::add_': lambda module_masks, mask: add_inshape(module_masks, mask), - 'aten::add': lambda module_mask, mask: add_inshape(module_mask, mask), - # mul has the similar behaviour with add, they both request - # the input tesors to have the same shape - 'aten::mul': lambda module_mask, mask: add_inshape(module_mask, mask), - 'aten::mul_': lambda module_mask, mask: add_inshape(module_mask, mask), - 'aten::cat': lambda module_mask, mask, cat_info, last_visited: cat_inshape(module_mask, mask, cat_info, last_visited), - 'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape), - 'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask), - 'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask), - 'aten::dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask), - 'aten::detach': lambda module_masks, mask: dropout_inshape(module_masks, mask) -} - -""" -Infer input and weight shape of a module/function from its output shape -""" -infer_from_outshape = { - 'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask), - 'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_outshape(module_masks, mask), - 'BatchNorm2d': lambda module_masks, mask: batchnorm2d_outshape(module_masks, mask), - - 'MaxPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask), - 'aten::max_pool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask), - 'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask), - 'aten::adaptive_avg_pool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask), - 'AvgPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask), - 'AdaptiveAvgPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask), - - 'ReLU': lambda module_masks, mask: relu_outshape(module_masks, mask), - 'PReLU': lambda module_masks, mask: prelu_outshape(module_masks, mask), - 'ReLU6': lambda module_masks, mask: relu_outshape(module_masks, mask), - 'aten::relu': lambda module_masks, mask: relu_outshape(module_masks, mask), - 'aten::tanh': lambda module_masks, mask: relu_outshape(module_masks, mask), - 'aten::tanh_': lambda module_masks, mask: relu_outshape(module_masks, mask), - 'aten::hardtanh': lambda module_masks, mask: relu_outshape(module_masks, mask), - 'aten::hardtanh_': lambda module_masks, mask: relu_outshape(module_masks, mask), - 'aten::relu_': lambda module_masks, mask: relu_outshape(module_masks, mask), - - 'aten::add_': lambda module_masks, mask: add_outshape(module_masks, mask), - 'aten::add': lambda 
module_mask, mask: add_outshape(module_mask, mask), - 'aten::flatten': lambda module_mask, mask, shape: view_outshape(module_mask, mask, shape), - 'aten::view': lambda module_masks, mask, shape: view_outshape(module_masks, mask, shape), - 'aten::reshape': lambda module_masks, mask, shape: view_outshape(module_masks, mask, shape), - 'aten::mean': lambda module_masks, mask, shape: mean_outshape(module_masks, mask, shape), - 'Dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask), - 'Dropout2d': lambda module_masks, mask: dropout_outshape(module_masks, mask), - 'aten::dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask), - 'aten::detach': lambda module_masks, mask: dropout_outshape(module_masks, mask) -} - - -def dropout_inshape(module_masks, mask): - if module_masks.input_mask is None: - module_masks.set_input_mask(mask) - module_masks.set_output_mask(mask) - return module_masks.output_mask - # if alreay visited - assert module_masks.input_mask <= mask - # It should be the same, we pass the masks by the reference(not the value), - # so they acutually are two references of the same object(mask, - # module_masks.input_mask). So we should continue pass the mask - # to the following nodes even module_masks.input_mask == mask. - # if pass the mask by copy.deepcopy(), then we can stop when - # module_masks.input_mask == mask. - # if module_masks.input_mask == mask: - # return None - module_masks.set_input_mask(mask) - module_masks.set_output_mask(mask) - return module_masks.output_mask - - -def dropout_outshape(module_masks, mask): - if module_masks.output_mask is None: - module_masks.set_output_mask(mask) - module_masks.set_input_mask(mask) - return module_masks.input_mask - # if alreay visited - assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1]) - - return module_masks.output_mask - - -def cat_inshape(module_masks, mask, cat_info, last_visited): - """ - Inference the output mask of the cat operation from the - input mask. - - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the Conv2d - mask : CoarseMask - The mask of its input tensor - cat_info: dict - Dict object that records the necessary information - of cat operation, such as the order of the input - tensors. - last_visited: str - The unique_name of the last visited node group. - - Returns - ------- - CoarseMask - The mask of its output tensor - - """ - assert isinstance(mask, CoarseMask) - out_shape = cat_info['out_shape'] - cat_dim = cat_info['cat_dim'] - in_order = cat_info['in_order'] - in_shape = cat_info['in_shape'] - if module_masks.output_mask is None: - # First visit to this cat node - # initialize the mask based on - # the number of the output channel. 
- output_mask = CoarseMask(num_dim=len(out_shape)) - for dim, _ in enumerate(out_shape): - if dim == cat_dim: - if mask.mask_index[dim] is None: - continue - device = mask.mask_index[dim].device - # calculate the offset of the mask - pos = in_order.index(last_visited) - offsets = [in_shape[i][cat_dim] - for i, _ in enumerate(in_shape)] - offset = 0 - for i in range(pos): - offset += offsets[i] - _tmp_mask = (mask.mask_index[dim] + offset).to(device) - output_mask.mask_index[dim] = _tmp_mask - else: - # directly copy the mask - if mask.mask_index[dim] is not None: - output_mask.mask_index[dim] = mask.mask_index[dim].data.clone( - ) - module_masks.set_output_mask(output_mask) - - return module_masks.output_mask - # If this cat node is already visited, we need - # validating if the mask is legel, for cat operation, - # the mask on the 'cat_dim' dimension should be stitched - # together. In the other dimensions, the mask should be - # the same, else the mask is not legal. - for dim, _ in enumerate(out_shape): - if dim == cat_dim: - if mask.mask_index[dim] is None: - continue - pos = in_order.index(last_visited) - offsets = [in_shape[i][cat_dim] for i, _ in enumerate(in_shape)] - offset = 0 - for i in range(pos): - offset += offsets[i] - device = mask.mask_index[dim].device - new_mask = mask.mask_index[dim] + offset - module_masks.output_mask.mask_index[dim] = CoarseMask.merge_index( - module_masks.output_mask.mask_index[dim], new_mask).to(device) - else: - assert module_masks.output_mask.eq_on_dim(mask, dim) - - return module_masks.output_mask - - -def add_inshape(module_masks, mask): - """ - Inference the output mask of the add operation from the - input mask. - """ - assert isinstance(mask, CoarseMask) - if module_masks.input_mask is None: - module_masks.set_input_mask(mask) - module_masks.set_output_mask(mask) - # module_masks.input_mask = mask - return mask - # If alreay visited, validate if have the conflict - # if the mask is different with previous input_mask - # then there is a mask confilct. - if mask != module_masks.input_mask: - raise Exception('Mask conflict happenes!') - return None - - -def add_outshape(module_masks, mask): - """ - Inference the input mask of the add operation from the - output mask. 
- """ - assert isinstance(mask, CoarseMask) - - if module_masks.output_mask is None: - module_masks.set_output_mask(mask) - module_masks.set_input_mask(mask) - return mask - else: - assert all( - module_masks.output_mask.mask_index[1] == mask.mask_index[1]) - return mask - - -def batchnorm2d_inshape(module_masks, mask): - """ - We assume only the second dimension has coarse grained mask - - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the batchnorm2d - mask : CoarseMask - The mask of its input tensor - - Returns - ------- - CoarseMask - The mask of its output tensor - """ - assert isinstance(mask, CoarseMask) - assert mask.mask_index[1] is not None - assert mask.mask_index[0] is None - assert mask.mask_index[2] is None - assert mask.mask_index[3] is None - module_masks.set_input_mask(mask) - module_masks.set_output_mask(mask) - weight_cmask = CoarseMask(num_dim=1) - weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1]) - module_masks.set_param_masks('weight', weight_cmask) - module_masks.set_param_masks('bias', weight_cmask) - return mask - - -def batchnorm2d_outshape(module_masks, mask): - """ - We assume only the second dimension has coarse grained mask - - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the batchnorm2d - mask : CoarseMask - The mask of its input tensor - - Returns - ------- - CoarseMask - The mask of its output tensor - """ - assert isinstance(mask, CoarseMask) - assert len(mask.mask_index) in [2, 4] - assert mask.mask_index[1] is not None - assert mask.mask_index[0] is None - module_masks.set_input_mask(mask) - module_masks.set_output_mask(mask) - weight_cmask = CoarseMask(num_dim=1) - weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1]) - module_masks.set_param_masks('weight', weight_cmask) - module_masks.set_param_masks('bias', weight_cmask) - return mask - - -def linear_inshape(module_masks, mask): - """ - Coarse grained input mask does not change the shape of weights and output tensor - - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the linear - mask : CoarseMask - The mask of its input tensor - - Returns - ------- - CoarseMask - The mask of its output tensor, ```None``` means shape of output tensor is not changed - """ - assert isinstance(mask, CoarseMask) - assert mask.mask_index[0] is None - if module_masks.input_mask is not None: - assert module_masks.input_mask <= mask - module_masks.set_input_mask(mask) - return None - - -def view_inshape(module_masks, mask, shape): - """ - This is a limited support - - TODO: consider replace tensor.view with nn.Flatten, because tensor.view is not - included in module, thus, cannot be replaced by our framework. 
- - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the ```view``` op - mask : CoarseMask - The mask of its input tensor - shape : dict - Original shape of its input and output tensors - - Returns - ------- - CoarseMask - The mask of its output tensor - """ - # NOTE: the case constrained by the following four asserts - assert shape['in_shape'][0] == shape['out_shape'][0] - assert len(shape['in_shape']) == 4 - assert len(shape['out_shape']) == 2 - assert shape['out_shape'][1] == shape['in_shape'][1] * \ - shape['in_shape'][2]*shape['in_shape'][3] - - assert isinstance(mask, CoarseMask) - assert mask.mask_index[1] is not None - assert mask.mask_index[0] is None - assert mask.mask_index[2] is None - assert mask.mask_index[3] is None - # due to the cat operation, the same node may be - # accessed more than once - if module_masks.input_mask is not None: - assert module_masks.input_mask <= mask - module_masks.set_input_mask(mask) - output_cmask = CoarseMask(num_dim=2) - index = [] - step_size = shape['in_shape'][2] * shape['in_shape'][3] - for loc in mask.mask_index[1]: - index.extend([loc * step_size + i for i in range(step_size)]) - output_cmask.add_index_mask(dim=1, index=torch.tensor(index).to(mask.mask_index[1].device)) # pylint: disable=not-callable - module_masks.set_output_mask(output_cmask) - return output_cmask - - -def view_outshape(module_masks, mask, shape): - """ - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the ```view``` op - mask : CoarseMask - The mask of its output tensor - shape : dict - Original shape of its input and output tensors - Returns - ------- - CoarseMask - The mask of its input tensor - """ - # NOTE: the case constrained by the following four asserts - assert shape['in_shape'][0] == shape['out_shape'][0] - assert len(shape['in_shape']) == 4 - assert len(shape['out_shape']) == 2 - assert shape['out_shape'][1] == shape['in_shape'][1] * \ - shape['in_shape'][2]*shape['in_shape'][3] - - assert isinstance(mask, CoarseMask) - assert mask.mask_index[1] is not None - assert mask.mask_index[0] is None - - module_masks.set_output_mask(mask) - input_cmask = CoarseMask(num_dim=4) - index = set() - step_size = shape['in_shape'][2] * shape['in_shape'][3] - for loc in mask.mask_index[1]: - index.add(loc // step_size) - index = sorted(list(index)) - input_cmask.add_index_mask(dim=1, index=torch.tensor(index).to(mask.mask_index[1].device)) # pylint: disable=not-callable - module_masks.set_input_mask(input_cmask) - - return input_cmask - - -def size_inshape(module_masks, mask): - """ - No need to do anything for this ```size``` op - """ - return None - - -def mean_inshape(module_masks, mask, shape): - """ - Similar to view operation, currently mask inference only supports - the mean operation on the 3rd and 4th dimensions. 
- """ - assert shape['in_shape'][0] == shape['out_shape'][0] - assert shape['out_shape'][1] == shape['in_shape'][1] - assert len(shape['in_shape']) == 4 - assert len(shape['out_shape']) == 2 - - assert isinstance(mask, CoarseMask) - assert mask.mask_index[1] is not None - assert mask.mask_index[0] is None - assert mask.mask_index[2] is None - assert mask.mask_index[3] is None - module_masks.set_input_mask(mask) - - output_cmask = CoarseMask(num_dim=2) - output_cmask.add_index_mask(dim=1, index=mask.mask_index[1]) - module_masks.set_output_mask(output_cmask) - return output_cmask - - -def mean_outshape(module_masks, mask, shape): - """ - Similar to view operation, currently mask inference only supports - the mean operation on the 3rd and 4th dimensions. - """ - assert shape['in_shape'][0] == shape['out_shape'][0] - assert shape['out_shape'][1] == shape['in_shape'][1] - assert len(shape['in_shape']) == 4 - assert len(shape['out_shape']) == 2 - - assert isinstance(mask, CoarseMask) - assert mask.mask_index[1] is not None - assert mask.mask_index[0] is None - module_masks.set_output_mask(mask) - - input_cmask = CoarseMask(num_dim=4) - input_cmask.add_index_mask(dim=1, index=mask.mask_index[1]) - module_masks.set_input_mask(input_cmask) - return input_cmask - - -def maxpool2d_inshape(module_masks, mask): - """ - Assume only the second dimension is masked - - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the maxpool2d - mask : CoarseMask - The mask of its input tensor - - Returns - ------- - CoarseMask - The mask of its output tensor - """ - assert isinstance(mask, CoarseMask) - assert mask.mask_index[1] is not None - assert mask.mask_index[0] is None - assert mask.mask_index[2] is None - assert mask.mask_index[3] is None - if module_masks.input_mask is not None: - assert module_masks.input_mask <= mask - # assert module_masks.input_mask is None - module_masks.set_input_mask(mask) - module_masks.set_output_mask(mask) - return mask - - -def maxpool2d_outshape(module_masks, mask): - """ - Assume only the second dimension is masked - - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the maxpool2d - mask : CoarseMask - The mask of its input tensor - - Returns - ------- - CoarseMask - The mask of its output tensor - """ - assert isinstance(mask, CoarseMask) - assert mask.mask_index[1] is not None - assert mask.mask_index[0] is None - - module_masks.set_input_mask(mask) - module_masks.set_output_mask(mask) - return mask - -def prelu_inshape(module_masks, mask): - """ - We assume only the second dimension has coarse grained mask - - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the PReLU - mask : CoarseMask - The mask of its input tensor - - Returns - ------- - CoarseMask - The mask of its output tensor - """ - assert isinstance(mask, CoarseMask) - assert mask.mask_index[1] is not None - assert mask.mask_index[0] is None - assert mask.mask_index[2] is None - assert mask.mask_index[3] is None - module_masks.set_input_mask(mask) - module_masks.set_output_mask(mask) - weight_cmask = CoarseMask(num_dim=1) - weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1]) - module_masks.set_param_masks('weight', weight_cmask) - return mask - -def prelu_outshape(module_masks, mask): - """ - We assume only the second dimension has coarse grained mask - - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the PReLU - mask : CoarseMask - The mask of its input tensor - - 
Returns - ------- - CoarseMask - The mask of its output tensor - """ - assert isinstance(mask, CoarseMask) - assert mask.mask_index[1] is not None - assert mask.mask_index[0] is None - assert mask.mask_index[2] is None - assert mask.mask_index[3] is None - - weight_cmask = CoarseMask(num_dim=4) - weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1]) - module_masks.set_param_masks('weight', weight_cmask) - - return mask - - -def relu_inshape(module_masks, mask): - """ - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the relu - mask : CoarseMask - The mask of its input tensor - - Returns - ------- - CoarseMask - The mask of its output tensor - """ - assert isinstance(mask, CoarseMask) - if module_masks.input_mask is not None: - # mask conflict should be solved before speedup - assert module_masks.input_mask <= mask - # assert module_masks.input_mask is None, "A relu op can only be processed once" - module_masks.set_input_mask(mask) - module_masks.set_output_mask(mask) - return mask - - -def relu_outshape(module_masks, mask): - """ - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the relu - mask : CoarseMask - The mask of its input tensor - - Returns - ------- - CoarseMask - The mask of its output tensor - """ - assert isinstance(mask, CoarseMask) - if module_masks.output_mask is not None: - # mask conflict should be solved before speedup - assert all( - module_masks.output_mask.mask_index[1] == mask.mask_index[1]) - module_masks.set_input_mask(mask) - module_masks.set_output_mask(mask) - return mask - - -def batchnorm2d_mask(module_masks, mask): - """ - Infer input and output shape from weight mask - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the batchnorm2d - mask : dict - The mask of its weights, from the user provided mask file - Returns - ------- - CoarseMask, CoarseMask - The mask of its input tensor, the mask of its output tensor - """ - assert 'weight' in mask and 'bias' in mask - sum_mask = mask['weight'] + mask['bias'] - nonzero_index = torch.nonzero(sum_mask, as_tuple=True)[0] - # infer shape of parameters - param_cmask = CoarseMask(num_dim=1) - param_cmask.add_index_mask(dim=0, index=nonzero_index) - module_masks.set_param_masks('weight', param_cmask) - module_masks.set_param_masks('bias', param_cmask) - # infer shape of input tensor - input_cmask = CoarseMask(num_dim=4) - input_cmask.add_index_mask(dim=1, - index=torch.nonzero(mask['weight'], as_tuple=True)[0]) - module_masks.set_input_mask(input_cmask) - # infer shape of output tensor - output_cmask = CoarseMask(num_dim=4) - output_cmask.add_index_mask(dim=1, index=nonzero_index) - module_masks.set_output_mask(output_cmask) - return input_cmask, output_cmask - - -def linear_mask(module_masks, mask, shape): - """ - Infer input and output shape from weight mask with limitations: - Only support infer input mask - - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the Linear - mask : dict - The mask of its weights, from the user provided mask file - shape: dict - Shape of its input and output tensors - Returns - ------- - CoarseMask, CoarseMask - The mask of its input tensor, the mask of its output tensor - """ - - assert 'weight' in mask - num_input_dim = len(shape['in_shape']) - - # Input data of Linear module can have multiple dimensions. 
- # here we only support infer coarse mask on the first dimension (dimension 0) - nonzero_index = torch.nonzero(mask['weight'].sum(0), as_tuple=True)[0] - - # infer shape of input tensor - input_cmask = CoarseMask(num_dim=num_input_dim) - input_cmask.add_index_mask(dim=num_input_dim-1, index=nonzero_index) - - module_masks.set_input_mask(input_cmask) - return input_cmask, None - - -def conv2d_mask(module_masks, mask): - """ - Infer input and output shape from weight mask - - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the conv2d - mask : dict - The mask of its weights, from the user provided mask file - - Returns - ------- - CoarseMask, CoarseMask - The mask of its input tensor, the mask of its output tensor - """ - def convert_to_coarse_mask(mask, dim=0): - """ - Parameters - ---------- - mask : dict - Weight mask from user provided mask file - dim: int - 0: filter pruning - 1: channel pruning - - Returns - ------- - LongTensor, CoarseMask, CoarseMask - Index of the masked dimension, weight mask, bias mask - """ - assert 'weight' in mask - assert isinstance(mask['weight'], torch.Tensor) - assert dim in [0, 1] - - weight_mask = mask['weight'] - - sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3) - index = torch.nonzero(weight_mask.abs().sum( - sum_idx) != 0, as_tuple=True)[0] - - index = index.long().to(weight_mask.device) - weight_cmask = CoarseMask(num_dim=4) - weight_cmask.add_index_mask(dim=dim, index=index) - bias_cmask = None - if dim == 0 and 'bias' in mask and mask['bias'] is not None: - bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0] - assert torch.all(torch.eq(index, bias_index)), \ - "bias mask should be consistent with weight mask" - bias_cmask = CoarseMask(num_dim=1) - bias_cmask.add_index_mask(dim=0, index=bias_index) - return index, weight_cmask, bias_cmask - - index, weight_cmask, bias_cmask = convert_to_coarse_mask( - mask, dim=conv_prune_dim) - - if index is None: - # TODO: fine grained mask speedup - return None, None - # deal with coarse grain mask - # mask conflict should be solved by fix_mask_conflict before speedup - if 'weight' in module_masks.param_masks: - assert module_masks.param_masks['weight'] == weight_cmask - else: - module_masks.set_param_masks('weight', weight_cmask) - if conv_prune_dim == 0: - module_masks.set_param_masks('bias', bias_cmask) - - io_cmask = CoarseMask(num_dim=4) - io_cmask.add_index_mask(dim=1, index=index) - - if conv_prune_dim == 0: - if module_masks.output_mask is None: - module_masks.set_output_mask(io_cmask) - else: - assert module_masks.output_mask == io_cmask - return None, module_masks.output_mask - else: - if module_masks.input_mask is None: - module_masks.set_input_mask(io_cmask) - else: - assert module_masks.input_mask == io_cmask - return module_masks.input_mask, None - - -def conv2d_inshape(module_masks, mask): - """ - Shape change of input tensor does not affect the shape of its output tensor - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the conv2d - mask : CoarseMask - The mask of its input tensor - Returns - ------- - CoarseMask - The mask of its output tensor - """ - assert isinstance(mask, CoarseMask) - if module_masks.input_mask is None: - module_masks.set_input_mask(mask) - else: - # the same conv layer may be accessed more - # than once, such as a concat operation. 
- # mask conflict should be solved by fix_mask_conflict before speedup - - assert module_masks.input_mask == mask - - # shape changes pass through depths wise conv layers - m = module_masks.module - if m.in_channels == m.out_channels == m.groups: - module_masks.output_mask = mask - module_masks.input_mask = mask - return mask - return None - - -def conv2d_outshape(module_masks, mask): - """ - Assume only the second dimension is masked - - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the conv2d - mask : CoarseMask - The mask of its output tensor - - Returns - ------- - CoarseMask - The mask of its input tensor - """ - assert isinstance(mask, CoarseMask) - assert mask.mask_index[1] is not None - assert mask.mask_index[0] is None - assert mask.mask_index[2] is None - assert mask.mask_index[3] is None - - if module_masks.output_mask is None: - module_masks.output_mask = mask - else: - # mask conflict should be solved by fix_mask_conflict before speedup - # mask and module_masks.output_mask may have different number of dimensions - # since they could be passed by linear or conv2d - assert all( - module_masks.output_mask.mask_index[1] == mask.mask_index[1]) - - weight_cmask = CoarseMask(num_dim=4) - weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1]) - bias_cmask = CoarseMask(num_dim=1) - bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1]) - module_masks.set_param_masks('weight', weight_cmask) - module_masks.set_param_masks('bias', bias_cmask) - - # shape changes pass through depths wise conv layers - m = module_masks.module - if m.in_channels == m.out_channels == m.groups: - module_masks.output_mask = mask - module_masks.input_mask = mask - return mask - return None - - -def convtranspose2d_mask(module_masks, mask): - # TODO support the Convtranspose2d Pruning for the L1FilterPruner - raise Exception( - "Current Filter pruner cannot prune the ConvTranspose2d, will support pruning ConvTranspose2d later") - - -def convtranspose2d_inshape(module_masks, mask): - """ - Shape change of input tensor does not affect the shape of its output tensor - Parameters - ---------- - module_masks : ModuleMasks - The ModuleMasks instance of the conv2d - mask : CoarseMask - The mask of its input tensor - Returns - ------- - CoarseMask - The mask of its output tensor - """ - assert isinstance(mask, CoarseMask) - if module_masks.input_mask is None: - module_masks.set_input_mask(mask) - else: - # the same conv layer may be accessed more - # than once, such as a concat operation. 
- # mask conflict should be solved by fix_mask_conflict before speedup - assert module_masks.input_mask == mask - - # shape changes pass through depths wise conv layers - m = module_masks.module - if m.in_channels == m.out_channels == m.groups: - module_masks.output_mask = mask - module_masks.input_mask = mask - return mask - return None - - -def convtranspose2d_outshape(module_masks, mask): - assert isinstance(mask, CoarseMask) - assert mask.mask_index[1] is not None - assert mask.mask_index[0] is None - assert mask.mask_index[2] is None - assert mask.mask_index[3] is None - - if module_masks.output_mask is None: - module_masks.output_mask = mask - else: - # mask conflict should be solved by fix_mask_conflict before speedup - # mask and module_masks.output_mask may have different number of dimensions - # since they could be passed by linear or conv2d - assert all( - module_masks.output_mask.mask_index[1] == mask.mask_index[1]) - - weight_cmask = CoarseMask(num_dim=4) - # Note the memory layout of Convtranspose2d is C_in, C_out, k1, k2 - weight_cmask.add_index_mask(dim=1, index=mask.mask_index[1]) - bias_cmask = CoarseMask(num_dim=1) - bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1]) - module_masks.set_param_masks('weight', weight_cmask) - module_masks.set_param_masks('bias', bias_cmask) - - # shape changes pass through depths wise conv layers - m = module_masks.module - if m.in_channels == m.out_channels == m.groups: - module_masks.output_mask = mask - module_masks.input_mask = mask - return mask - return None diff --git a/nni/compression/pytorch/speedup/jit_translate.py b/nni/compression/pytorch/speedup/jit_translate.py new file mode 100644 index 0000000000..ac051c73af --- /dev/null +++ b/nni/compression/pytorch/speedup/jit_translate.py @@ -0,0 +1,553 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import re +import logging +from functools import partial +import torch + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def translate_list(list_node, speedup=None): + """ + Get the list of values from the list construct node. + Parameters + --------- + list_node: Torch.C.Value + The cpp node of the target list. + speedup: ModuleSpeed + The Module speedup module. + Returns + ------- + values: list + The list of values in the target cpp list node. + """ + # the node that create the list + create_node = list_node.node() + assert create_node.kind() == 'prim::ListConstruct' + inputs = list(create_node.inputs()) + values = [] + for _i in inputs: + debugName = _i.debugName() + if speedup is not None and debugName in speedup.internal_result: + # this value is the result of the other nodes, such as + # ate::size + values.append(speedup.internal_result[debugName].item()) + else: + # if the corresponding value is a constant + values.append(_i.toIValue()) + return values + + +def parse_constant(cvalue, speedup): + """ + Parse the constant values from this Node + Parameters + ---------- + cvalue: Torch.C.Value + The cpp node of the target constant value. + speedup: ModelSpeedup + The Model speedup module. + Returns + ------- + value: int/float/tensor + The constant values parsed from the node. 
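+
+    A sketch of the recursion on a hypothetical IR fragment::
+
+        %two  : int = prim::Constant[value=2]()
+        %half : int = aten::floor_divide(%size, %two)
+
+    Parsing ``%half`` first parses ``%size`` (already recorded in
+    speedup.internal_result) and ``%two`` (a plain constant resolved via
+    toIValue), then applies the translated floor_divide to the two values.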
+ """ + logger.debug('Try to parse the constant value: %s', cvalue.debugName()) + if cvalue.toIValue() is not None: + return cvalue.toIValue() + if cvalue.debugName() in speedup.internal_result: + return speedup.internal_result[cvalue.debugName()] + # Get the operator node of the this value + op_node = cvalue.node() + + inputs = op_node.inputs() + input_values = [parse_constant(_i, speedup) for _i in inputs] + func = trans_from_jit_to_python[op_node.kind()](op_node, speedup) + return func(*input_values) + + +def dropout_python(node, speedup): + return torch.dropout + + +def flatten_python(node, speedup): + c_node = node.key_node + inputs = list(c_node.inputs()) + start_dim = inputs[1].toIValue() + end_dim = inputs[2].toIValue() + new_flatten = partial(torch.flatten, start_dim=start_dim, end_dim=end_dim) + return new_flatten + + +def relu_inplace_python(node, speedup): + return torch.relu_ + + +def relu_python(node, speedup): + return torch.relu + + +def sigmoid_python(node, speedup): + return torch.sigmoid + + +def mean_python(node, speedup): + c_node = node.key_node + inputs = list(c_node.inputs()) + dim_list = translate_list(inputs[1], speedup) + keep_dim = inputs[2].toIValue() + new_mean = partial(torch.mean, dim=tuple(dim_list), keepdim=keep_dim) + return new_mean + + +def add_python(node, speedup): + c_node = node.key_node + inputs = list(c_node.inputs()) + constant = None + for i in range(2): + input_i = inputs[i] + debug_name = input_i.debugName() + if debug_name not in speedup.internal_result: + # this input is a constant value + # TODO: what if this input is a constant tensor + + if input_i.toIValue() is not None: + constant = parse_constant(input_i, speedup) + break + if constant is None: + return torch.add + else: + new_add = partial(torch.add, constant) + return new_add + + +def floor_div_python(node, speedup): + c_node = node.key_node + inputs = list(c_node.inputs()) + divisor = inputs[1] + constant = None + if divisor.debugName() not in speedup.internal_result: + # divisor is a constant value/tensor + constant = parse_constant(divisor, speedup) + if constant is None: + return torch.floor_divide + else: + new_op = partial(torch.floor_divide, other=constant) + return new_op + + +def mul_python(node, speedup): + c_node = node.key_node + inputs = list(c_node.inputs()) + constant = None + for i in range(2): + input_i = inputs[i] + debug_name = input_i.debugName() + if debug_name not in speedup.internal_result: + constant = parse_constant(input_i, speedup) + # both two inputs cannot be constants at the same time + break + if constant is None: + return torch.mul + else: + new_mul = partial(torch.mul, constant) + return new_mul + + +def transpose_python(node, speedup): + return torch.t + + +def transpose2_python(node, speedup): + c_node = node.key_node + inputs = list(c_node.inputs()) + dim_1 = inputs[1].toIValue() + dim_2 = inputs[2].toIValue() + new_transpose = partial(torch.transpose, dim0=dim_1, dim1=dim_2) + return new_transpose + + +def matmul_python(node, speedup): + return torch.matmul + + +def div_python(node, speedup): + # The second input parameter of torch.div can be a + # tensor or a constant, if it is a constant, we need + # to return + c_node = node.key_node + inputs = list(c_node.inputs()) + if inputs[1].debugName() in speedup.internal_result: + # the second input parameters is the output of the other + # nodes + return torch.div + else: + other = inputs[1].toIValue() + new_div = partial(torch.div, other=other) + + return new_div + + +def softmax_python(node, 
speedup):
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    dim = inputs[1].toIValue()
+    new_softmax = partial(torch.softmax, dim=dim)
+    return new_softmax
+
+
+def contiguous_python(node, speedup):
+    class contiguousModule(torch.nn.Module):
+        def forward(self, x):
+            return x.contiguous()
+    return contiguousModule()
+
+
+def gelu_python(node, speedup):
+    return torch.nn.GELU()
+
+
+def avgpool2d_python(node, speedup):
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    kernel_size = translate_list(inputs[1], speedup)
+    stride = translate_list(inputs[2], speedup)
+    padding = translate_list(inputs[3], speedup)
+    new_avgpool = partial(torch.nn.functional.avg_pool2d,
+                          kernel_size=kernel_size, stride=stride, padding=padding)
+    return new_avgpool
+
+
+def adaptive_avgpool_python(node, speedup):
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    output_size = translate_list(inputs[1], speedup)
+    new_avgpool = torch.nn.AdaptiveAvgPool2d(output_size)
+    return new_avgpool
+
+
+def tupleunpack_python(node, speedup):
+    # Note: tuple unpack should only exist at the end of the model,
+    # and there is no need to replace/propagate the mask for it
+    return None
+
+
+def num2tensor_python(node, speedup):
+    return torch.nn.Identity()
+
+
+def exp_python(node, speedup):
+    return torch.exp
+
+
+def squeeze_python(node, speedup):
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    dim = None
+    if len(inputs) > 1:
+        dim = parse_constant(inputs[1], speedup)
+    new_squeeze = partial(torch.squeeze, dim=dim)
+    return new_squeeze
+
+##########################################################
+# Split Line
+# The following modules/functions cannot be translated into a
+# single function, so we use a torch.nn.Module to wrap the
+# core function and return the torch.nn.Module instead
+##########################################################
+
+
+def slice_python(node, speedup):
+    class SliceModule(torch.nn.Module):
+        def __init__(self, sliceobj):
+            super(SliceModule, self).__init__()
+            self.sliceobj = sliceobj
+
+        def forward(self, x, *args):
+            # args is for the slice dimension and indexes; however, we
+            # already get them from the cpp nodes. Note that, although we
+            # don't need the slice indexes any more, we cannot remove this
+            # parameter here, because there may be multiple inputs passed
+            # from previous nodes such as aten::size
+            logger.info('Model has Slice operation, and the operand size=%s, Slice object:%s', str(
+                x.size()), str(self.sliceobj))
+            return x[self.sliceobj]
+
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+
+    slice_dim = parse_constant(inputs[1], speedup)
+    slice_start = parse_constant(inputs[2], speedup)
+    slice_end = parse_constant(inputs[3], speedup)
+    slice_step = parse_constant(inputs[4], speedup)
+    slice_obj = slice(slice_start, slice_end, slice_step)
+    slice_list = []
+    for _ in range(slice_dim):
+        slice_list.append(slice(None, None))
+    logger.info('Slice dim:%s, Slice obj:%s', str(slice_dim), str(slice_obj))
+    slice_list.append(slice_obj)
+    return SliceModule(tuple(slice_list))
+
+
+def select_python(node, speedup):
+    class SelectModule(torch.nn.Module):
+        def __init__(self, dim, index):
+            super(SelectModule, self).__init__()
+            self.dim = dim
+            self.index = index
+
+        def forward(self, x):
+            return x.select(self.dim, self.index)
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    dim = inputs[1].toIValue()
+    index = inputs[2].toIValue()
+    return SelectModule(dim, index)
+
+
+def size_python(node, speedup):
+    class SizeModule(torch.nn.Module):
+        def __init__(self, sizedim):
+            super(SizeModule, self).__init__()
+            self.sizedim = sizedim
+
+        def forward(self, x):
+            return torch.as_tensor([x.size(self.sizedim)], dtype=torch.long)
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    size_dim = inputs[1].toIValue()
+    return SizeModule(size_dim)
+
+
+def toint_python(node, speedup):
+    class ToIntModule(torch.nn.Module):
+        def forward(self, x):
+            return x.to(torch.int)
+    return ToIntModule()
+
+
+def view_python(node, speedup):
+    class ViewModule(torch.nn.Module):
+        def __init__(self, shape):
+            super(ViewModule, self).__init__()
+            self.shape = shape
+            logger.info('View Module output size: %s', str(self.shape))
+
+        def forward(self, *args):
+            return args[0].view(self.shape)
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    shape = translate_list(inputs[1], speedup)
+    return ViewModule(shape)
+
+
+def reshape_python(node, speedup):
+    class ReshapeModule(torch.nn.Module):
+        def __init__(self, shape):
+            super(ReshapeModule, self).__init__()
+            self.shape = shape
+            logger.info('Reshape Module output size: %s', str(self.shape))
+
+        def forward(self, *args):
+            return args[0].view(self.shape)
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    shape = translate_list(inputs[1], speedup)
+    return ReshapeModule(shape)
+
+
+def permute_python(node, speedup):
+    class PermuteModule(torch.nn.Module):
+        def __init__(self, dimlist):
+            super(PermuteModule, self).__init__()
+            self.dimlist = dimlist
+
+        def forward(self, x):
+            return x.permute(self.dimlist)
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    dim_list = translate_list(inputs[1], speedup)
+    return PermuteModule(dim_list)
+
+
+def getattr_python(node, speedup):
+    """
+    Note: ops starting with prim:: are not taken as key nodes,
+    so we directly pass the cpp node into this function.
+
+    Parameters
+    ----------
+    node: torch._C.Node
+        The cpp node of prim::GetAttr
+    speedup: ModelSpeedup
+        The corresponding speedup object.
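+
+    A sketch of the name extraction performed below (the node string is
+    illustrative)::
+
+        import re
+        node_str = '%3 : Tensor = prim::GetAttr[name="conv1"](%self.1)'
+        re.findall(r'\[name="(.*?)"\]', node_str)   # -> ['conv1']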
+    """
+    class GetModule(torch.nn.Module):
+        def __init__(self, key):
+            super(GetModule, self).__init__()
+            self.key = key
+
+        def forward(self, obj):
+            logger.info('Get attribute: %s', self.key)
+            return getattr(obj, self.key)
+    # get the name of the attribute, for example
+    # prim::GetAttr[name="module_list"](%self.1)
+    assert node.kind() == 'prim::GetAttr'
+    pattern = r'\[name="(.*?)"\]'
+    key_words = re.findall(pattern, str(node))
+    assert len(key_words) == 1
+    return GetModule(key_words[0])
+
+
+def upsample_bilinear2d_python(node, speedup):
+    class UpsampleModule(torch.nn.Module):
+        def __init__(self, size_list, scale_list):
+            super(UpsampleModule, self).__init__()
+            self.size_list = size_list
+            self.scale_list = scale_list
+
+        def forward(self, *args):
+            """
+            The first element of args is the target tensor to upsample;
+            the following parameters are not needed, because we already
+            get the size_list and the scale_list by parsing the cpp_nodes.
+            """
+            return torch.nn.functional.upsample_bilinear(args[0],
+                                                         size=self.size_list, scale_factor=self.scale_list)
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    size_list_node = inputs[1].node()
+    scale_list_node = inputs[3].node()
+    size_list = None
+    scale_list = None
+
+    if size_list_node.kind() == 'prim::ListConstruct':
+        size_list = translate_list(inputs[1], speedup)
+    if scale_list_node.kind() == 'prim::ListConstruct':
+        scale_list = translate_list(inputs[3], speedup)
+    return UpsampleModule(size_list, scale_list)
+
+
+def typeas_python(node, speedup):
+    """
+    Currently we only support type_as float.
+    TODO: support more types in type_as; we need to figure out
+    how to get the scalar type from torch._C.TensorType.
+    """
+    class TypeasModule(torch.nn.Module):
+        def __init__(self, dtype=torch.float):
+            super(TypeasModule, self).__init__()
+            self.example = torch.zeros(1, dtype=dtype)
+
+        def forward(self, x):
+            return x.type_as(self.example)
+    return TypeasModule()
+
+
+def to_python(node, speedup):
+    # for the time being, only the device parameter is supported
+    class ToModule(torch.nn.Module):
+        def __init__(self, device):
+            super(ToModule, self).__init__()
+            self.device = device
+
+        def forward(self, x):
+            return x.to(self.device)
+
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    device = inputs[3].toIValue()
+    return ToModule(device)
+
+
+def cat_python(node, speedup):
+    class CatModule(torch.nn.Module):
+        def __init__(self, cat_dim):
+            super(CatModule, self).__init__()
+            self.cat_dim = cat_dim
+
+        def forward(self, *args):
+            return torch.cat(args, dim=self.cat_dim)
+
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    dim = inputs[1].toIValue()
+    return CatModule(dim)
+
+
+trans_from_jit_to_python = {
+    'aten::add': add_python,
+    'aten::add_': add_python,
+    'aten::mul': mul_python,
+    'aten::mul_': mul_python,
+    'aten::relu': relu_python,
+    'aten::relu_': relu_inplace_python,
+    'aten::sigmoid': sigmoid_python,
+    'aten::sigmoid_': sigmoid_python,
+    # for the purpose of mask inference, tanh behaves like relu
+    'aten::tanh': relu_python,
+    'aten::tanh_': relu_python,
+    'aten::flatten': flatten_python,
+    'aten::mean': mean_python,
+    'aten::dropout': dropout_python,
+    'aten::slice': slice_python,
+    'aten::select': select_python,
+    'aten::size': size_python,
+    'aten::t': transpose_python,
+    'aten::transpose': transpose2_python,
+    'aten::Int': toint_python,
+    'aten::view': view_python,
+    'aten::reshape': reshape_python,
+    'aten::permute': permute_python,
+    'aten::matmul': matmul_python,
+    'aten::div': div_python,
+    'aten::floor_divide': floor_div_python,
+    'aten::softmax': softmax_python,
+    'aten::contiguous': contiguous_python,
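+    # NOTE: some entries below reuse looser translations (e.g. max_pool2d
+    # reuses the avg_pool2d translation, whose kernel/stride/padding inputs
+    # sit at the same positions); as with the tanh entry above, only the
+    # input/output dependency pattern matters for mask inference, not the
+    # exact output values.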
+    'aten::gelu': gelu_python,
+    'aten::cat': cat_python,
+    'aten::avg_pool2d': avgpool2d_python,
+    'aten::max_pool2d': avgpool2d_python,
+    'aten::adaptive_avg_pool2d': adaptive_avgpool_python,
+    'aten::to': to_python,
+    'aten::type_as': typeas_python,
+    'aten::upsample_bilinear2d': upsample_bilinear2d_python,
+    'aten::exp': exp_python,
+    'aten::squeeze': squeeze_python,
+    'prim::TupleUnpack': tupleunpack_python,
+    'prim::ListUnpack': tupleunpack_python,
+    'prim::NumToTensor': num2tensor_python,
+    'prim::GetAttr': getattr_python
+}
+
+
+def jit_to_python_function(node, speedup):
+    """
+    Return a callable object to infer the mask according to the
+    node.op_type.
+
+    Parameters
+    ----------
+    node: NodeGroup
+        The target node whose mask needs to be inferred
+    speedup: ModelSpeedup
+        The speedup object of the target model.
+
+    Returns
+    -------
+    func: callable object (nn.Module/function)
+        The translated function used to infer the mask;
+        if the current op_type is not supported, None is returned.
+    """
+    logger.debug(
+        'Translate C function %s into its python version', node.op_type)
+    if node.op_type not in trans_from_jit_to_python:
+        logger.error(
+            '%s is not supported! Please report an issue at https://github.com/microsoft/nni. Thanks~', node.op_type)
+        # return None to skip the mask inference for this node
+        return None
+    return trans_from_jit_to_python[node.op_type](node, speedup)
diff --git a/nni/compression/pytorch/utils/__init__.py b/nni/compression/pytorch/utils/__init__.py
index e69de29bb2..90f60fdd89 100644
--- a/nni/compression/pytorch/utils/__init__.py
+++ b/nni/compression/pytorch/utils/__init__.py
@@ -0,0 +1 @@
+from .utils import *
\ No newline at end of file
diff --git a/nni/compression/pytorch/utils/mask_conflict.py b/nni/compression/pytorch/utils/mask_conflict.py
index e89372d60e..b797d61f25 100644
--- a/nni/compression/pytorch/utils/mask_conflict.py
+++ b/nni/compression/pytorch/utils/mask_conflict.py
@@ -4,10 +4,10 @@
 import logging
 import torch
 import numpy as np
-from .shape_dependency import ChannelDependency, GroupDependency, CatPaddingDependency, InputChannelDependency
+from .shape_dependency import ChannelDependency, GroupDependency, InputChannelDependency
 from .utils import get_module_by_name
 # logging.basicConfig(level = logging.DEBUG)
-_logger = logging.getLogger(__name__)
+_logger = logging.getLogger('FixMaskConflict')


 def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
@@ -21,7 +21,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
         A dict object that stores the masks or the path of the mask file
     model : torch.nn.Module
         model to fix the mask conflict
-    dummy_input : torch.Tensor
+    dummy_input : torch.Tensor/list of tensors/dict of tensors
         input example to trace the model
     traced : torch._C.torch.jit.TopLevelTracedModule
         the traced model of the target model, is this parameter is not None,
@@ -48,9 +48,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
     masks = fix_group_mask.fix_mask()
     fix_channel_mask = ChannelMaskConflict(masks, model, dummy_input, traced)
     masks = fix_channel_mask.fix_mask()
-    padding_cat_mask = CatMaskPadding(masks, model, dummy_input, traced)
-    masks = padding_cat_mask.fix_mask()
-    return masks, fix_channel_mask.conv_prune_dim
+    return masks


 class MaskFix:
@@ -78,70 +76,6 @@ def export(self, path):
         torch.save(self.masks, path)


-class CatMaskPadding(MaskFix):
-    def __init__(self, masks, model, dummy_input=None, traced=None):
-        """
CatMaskPadding find the layers whose output tensor is passed - to the same cat operation. The cat operation concatnates the - masks of the input tensors as the output mask, so when some - of the input layers of the cat operation are not pruned, we still - need to pass the masks of these non-pruned layers(the mask are - all ones) to the cat operation to ensure the shape of the output - mask is right. - - Parameters - ---------- - masks : dict - a dict object that stores the masks - model : torch.nn.Module - model to fix the mask conflict - dummy_input : torch.Tensor - input example to trace the model - traced : torch._C.torch.jit.TopLevelTracedModule - the traced model of the target model, is this parameter is not None, - we donnot use the model and dummpy_input to get the trace graph. - """ - super(CatMaskPadding, self).__init__(masks, model, dummy_input, traced) - - def fix_mask(self): - cat_padding_depen = CatPaddingDependency( - self.model, self.dummy_input, self.traced) - name_to_module = {} - for name, module in self.model.named_modules(): - name_to_module[name] = module - depen = cat_padding_depen.dependency_sets - for layers in depen: - device = None - count = 0 - for layer in layers: - if layer in self.masks: - count += 1 - if device is None: - device = self.masks[layer]['weight'].device - if count == 0: - # no layer is pruned - continue - elif count == len(layers): - # all the layers have been pruned - continue - # pad the mask for the non-pruned layers - for layer in layers: - if layer in self.masks: - continue - - module = name_to_module[layer] - w_shape = module.weight.data.size() - w_mask = torch.ones(w_shape).to(device) - b_mask = None - if hasattr(module, 'bias') and module.bias is not None: - # module.bias may be None - b_shape = module.bias.data.size() - b_mask = torch.ones(b_shape).to(device) - self.masks[layer] = {'weight': w_mask, 'bias': b_mask} - - return self.masks - - class GroupMaskConflict(MaskFix): def __init__(self, masks, model=None, dummy_input=None, traced=None): """ @@ -172,9 +106,11 @@ def fix_mask(self): group_depen = GroupDependency( self.model, self.dummy_input, self.traced) depens = group_depen.dependency + min_groups = group_depen.min_groups _logger.info(depens) for layername in depens: - group = depens[layername] + group_max = depens[layername] + group_min = min_groups[layername] if layername not in self.masks: # this layer not pruned continue @@ -187,29 +123,43 @@ def fix_mask(self): # In fine-grained pruning, skip this layer _logger.info('Layers %s using fine-grained pruning', layername) continue - assert shape[0] % group == 0 + assert shape[0] % group_max == 0 # Find the number of masked filter for each group (mini_masked). # Because we have to keep the pruned filter can still # be divided into the same number of groups, so we only can # prune mini_masked filters for each group. - step = shape[0] / group + step = shape[0] / group_max group_masked = [] - for i in range(group): + for i in range(group_max): _start = step * i _end = step * (i + 1) _tmp_list = list( filter(lambda x: _start <= x and x < _end, all_zeros)) group_masked.append(_tmp_list) mini_masked = min([len(x) for x in group_masked]) + need_unmask = set() for gm in group_masked: for i in range(mini_masked, len(gm)): # To keep the output channel number still being divisible to # groups, we set the masks of following filters to be zero. 
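+                    # Illustrative example (hypothetical numbers): with 8
+                    # filters and collected groups [2, 4], group_max = 4
+                    # (step = 2). If filters {0, 1, 2, 3, 6} are pruned,
+                    # mini_masked = 0, so all of them are queued for
+                    # unmasking here; the group_min pass below then detects
+                    # that min-group [0..3] is fully pruned and keeps those
+                    # four filters masked, so only filter 6 is unmasked.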
                     pos = gm[i]
-                    self.masks[layername]['weight'][pos] = torch.ones(
-                        shape[1:])
-                    if 'bias' in self.masks[layername] and self.masks[layername]['bias'] is not None:
-                        self.masks[layername]['bias'][pos] = 1
+                    need_unmask.add(pos)
+            step = shape[0] / group_min
+            for i in range(group_min):
+                _start = step * i
+                _end = step * (i+1)
+                _tmp_list = list(
+                    filter(lambda x: _start <= x and x < _end, all_zeros))
+                if len(_tmp_list) == step:
+                    # if the whole group is removed, then we don't have to
+                    # unmask the filters in this group
+                    for pos in _tmp_list:
+                        if pos in need_unmask:
+                            need_unmask.remove(pos)
+            for pos in need_unmask:
+                self.masks[layername]['weight'][pos] = torch.ones(shape[1:])
+                if 'bias' in self.masks[layername] and self.masks[layername]['bias'] is not None:
+                    self.masks[layername]['bias'][pos] = 1
         return self.masks

@@ -234,9 +184,14 @@ def __init__(self, masks, model=None, dummy_input=None, traced=None):
         super(ChannelMaskConflict, self).__init__(
             masks, model, dummy_input, traced)
         self.conv_prune_dim = detect_mask_prune_dim(masks, model)
-        _logger.info('detected conv prune dim: %s', self.conv_prune_dim)
+        _logger.info('Detected conv prune dim: %d', self.conv_prune_dim)

     def fix_mask(self):
         """
         Fix the mask conflict before the mask inference for the layers that
         has shape dependencies. This function should be called before the
@@ -274,7 +229,12 @@ def fix_mask(self):
                 if (channel_mask.sum() * (mask.numel() / mask.shape[self.conv_prune_dim])).item() != (mask > 0).sum().item():
                     fine_grained = True
             elif type(m).__name__ == 'Linear':
-                channel_masks.append((mask.abs().sum(0) != 0).int())
+                if self.conv_prune_dim == 1:
+                    channel_masks.append(
+                        (mask.abs().sum(0) != 0).int())
+                else:
+                    channel_masks.append(
+                        (mask.abs().sum(1) != 0).int())
             elif type(m).__name__ == 'BatchNorm2d':
                 channel_masks.append(mask.int())
             elif type(m).__name__ == 'ConvTranspose2d':
@@ -293,9 +253,7 @@ def fix_mask(self):
                 # no mask means not pruned, equivlent to full masks
                 channel_masks.append(None)
         if fine_grained:
-            _logger.info(
-                'fine-grained mask detected, skip solving conflict for this set: %s', dset)
-            continue
+            _logger.info('Fine-grained mask detected')
         if all(x is None for x in channel_masks):
             continue
         num_channels_list = [len(x)
@@ -306,7 +264,8 @@ def fix_mask(self):

         for i, dim_mask in enumerate(channel_masks):
             if dim_mask is None:
-                channel_masks[i] = torch.ones(num_channels).int().to(device)
+                channel_masks[i] = torch.ones(
+                    num_channels).int().to(device)

         # merge masks with 'or'
         merged_channel_mask = channel_masks[0].clone()
@@ -329,19 +288,22 @@ def fix_mask(self):
             else:
                 new_mask[:, merged_index, :, :] = 1.
         elif type(m).__name__ == 'Linear':
-            new_mask[:, merged_index] = 1.
+            if self.conv_prune_dim == 0:
+                new_mask[merged_index, :] = 1
+            elif self.conv_prune_dim == 1:
+                new_mask[:, merged_index] = 1.
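+            # for a Linear weight of shape [out_features, in_features], the
+            # merged channel mask is applied to the rows (output features)
+            # when conv_prune_dim == 0 and to the columns (input features)
+            # when conv_prune_dim == 1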
         elif type(m).__name__ == 'BatchNorm2d':
             new_mask = merged_channel_mask.type_as(orig_mask)
         else:
             raise RuntimeError(
                 f'unsupported module type: {type(m).__name__}')
-
         self.masks[name]['weight'] = new_mask
         if 'bias' in self.masks[name] and self.masks[name]['bias'] is not None:
             if type(m).__name__ == 'Conv2d':
                 assert self.conv_prune_dim == 0
-                self.masks[name]['bias'] = merged_channel_mask.type_as(
-                    self.masks[name]['bias'])
+            if self.conv_prune_dim == 0:
+                self.masks[name]['bias'] = merged_channel_mask.type_as(
+                    self.masks[name]['bias'])

     return self.masks

@@ -349,14 +311,12 @@ def fix_mask(self):
 def detect_mask_prune_dim(masks, model):
     """
     Detect how the masks of convolutional layers are pruned.
-
     Parameters
     ----------
     masks: dict
         A dict object that stores the masks.
     model: nn.Module
         Model object which the mask can be applied on.
-
     Returns:
     -------
     How the masks of convolutional layers are pruned, this depends on pruning algorithms, it should
diff --git a/nni/compression/pytorch/utils/shape_dependency.py b/nni/compression/pytorch/utils/shape_dependency.py
index 6c7491897b..b8e6dc896f 100644
--- a/nni/compression/pytorch/utils/shape_dependency.py
+++ b/nni/compression/pytorch/utils/shape_dependency.py
@@ -3,18 +3,34 @@
 import csv
 import logging
+import numpy as np

-__all__ = ['ChannelDependency', 'GroupDependency',
-           'CatPaddingDependency', 'InputChannelDependency']
+
+__all__ = ['ChannelDependency', 'GroupDependency', 'InputChannelDependency']

 CONV_TYPE = 'aten::_convolution'
 ADD_TYPES = ['aten::add', 'aten::add_']
+MUL_TYPES = ['aten::mul', 'aten::mul_']
 CAT_TYPE = 'aten::cat'
 logger = logging.getLogger('Shape_Dependency')
 RESHAPE_OPS = [CAT_TYPE, 'aten::view',
                'aten::reshape', 'aten::flatten', 'aten::mean']


+def lcm_list(L):
+    """Return the least common multiple of the integers in L."""
+    lcm = 1
+    for i in L:
+        lcm = np.lcm(lcm, i)
+    return lcm
+
+
+def gcd_list(L):
+    """Return the greatest common divisor of the integers in L."""
+    gcd = L[0]
+    for i in L:
+        gcd = np.gcd(gcd, i)
+    return gcd
+
+
 class Dependency:
     def __init__(self, model=None, dummy_input=None, traced_model=None):
         """
@@ -38,6 +54,35 @@ def export(self, filepath):
         raise NotImplementedError


+def reshape_break_channel_dependency(op_node):
+    """
+    Reshape operations (such as reshape, view and flatten) may break
+    the channel dependency. We need to check the input parameters of
+    these reshape operations to see whether a reshape node breaks the
+    channel dependency. However, it is complicated to analyze the input
+    parameters of each reshape function and infer whether it breaks the
+    channel dependency. So currently, we just check whether the input
+    channel count equals the output channel count: if so, the reshape
+    function does not change the number of channels, which means the
+    channel dependency is not broken; otherwise, the reshape operation
+    changes the number of channels and therefore breaks the channel
+    dependency.
+
+    Parameters
+    ----------
+    op_node: NodePyOP
+        An op node of the graph.
+    Returns
+    -------
+    bool
+        True if this operation breaks the channel dependency.
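+
+    Examples
+    --------
+    (illustrative) A view from shape [N, 8, H, W] to [N, 8, H*W] keeps the
+    channel count, so the dependency is preserved (returns False); a view
+    from [N, 8, H, W] to [N, 16, (H*W)//2] changes the channel count from
+    8 to 16, so the dependency is broken (returns True).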
+ """ + in_shape = op_node.auxiliary['in_shape'] + out_shape = op_node.auxiliary['out_shape'] + in_channel = in_shape[1] + out_channel = out_shape[1] + return in_channel != out_channel + + class ChannelDependency(Dependency): def __init__(self, model=None, dummy_input=None, traced_model=None): """ @@ -80,6 +125,9 @@ def _get_parent_layers(self, node): # find the first met conv parent_layers.append(curnode.name) continue + elif curnode.op_type in RESHAPE_OPS: + if reshape_break_channel_dependency(curnode): + continue parents = self.graph.find_predecessors(curnode.unique_name) parents = [self.graph.name_to_node[name] for name in parents] for parent in parents: @@ -176,7 +224,7 @@ def dependency_sets(self): d_sets = [] visited = set() for node in self.graph.nodes_py.nodes_op: - if node.op_type != 'Conv2d' or node in visited: + if (node.op_type != 'Conv2d' and node.op_type != 'Linear') or node in visited: continue tmp_set = set() if node.name not in self.dependency: @@ -190,35 +238,6 @@ def dependency_sets(self): return d_sets -def reshape_break_channel_dependency(op_node): - """ - The reshape operations such as (reshape, view, flatten) may break - the channel dependency. We need to check the input parameters of - these reshape operations to check if this reshape node will break - the channel dependency. However, it's complicated to analyze the the input - parameters for each reshape function and infer if it will break the channel - dependency. So currently, we just check if the input channel and the output - channel is the same, if so, then we can say the original reshape function - doesn't want to change the number of the channels, which means the channel - dependency is not broken. In contrast, the original reshap operation wants - to change the number of channels, so it breaks the channel dependency. - - Parameters - ---------- - opnode: NodePyOP - A Op node of the graph. - Returns - ------- - bool - If this operation will break the channel dependency. - """ - in_shape = op_node.auxiliary['in_shape'] - out_shape = op_node.auxiliary['out_shape'] - in_channel = in_shape[1] - out_channel = out_shape[1] - return in_channel != out_channel - - class InputChannelDependency(ChannelDependency): """ Some pruners may prune the input channel of the convolutional @@ -295,67 +314,6 @@ def build_dependency(self): self.dependency[layer] = dependency_set -class CatPaddingDependency(ChannelDependency): - def __init__(self, model=None, dummy_input=None, traced_model=None): - super(CatPaddingDependency, self).__init__( - model, dummy_input, traced_model) - - def build_dependency(self): - """ - Build the cat padding dependencies. - If the output features of several layers are stitched together - by cat operation, then these layers have cat padding dependencies. - This is because when inferring the cat mask, we need all the input - masks for the cat operation. At this time we need to know the source - of all input vectors of a cat operation. 
-        """
-        for node in self.graph.nodes_py.nodes_op:
-            parent_layers = []
-            if node.op_type == CAT_TYPE:
-                parent_layers = self._get_parent_layers(node)
-                dependency_set = set(parent_layers)
-                # merge the dependencies
-                for parent in parent_layers:
-                    if parent in self.dependency:
-                        dependency_set.update(self.dependency[parent])
-                # save the dependencies
-                for _node in dependency_set:
-                    self.dependency[_node] = dependency_set
-
-    @property
-    def dependency_sets(self):
-        d_sets = []
-        visited = set()
-        for nodename in self.dependency:
-            if nodename in visited:
-                continue
-            d_sets.append(self.dependency[nodename])
-        return d_sets
-
-    def export(self, filepath):
-        """
-        Export the dependencies into a file.
-        In the output file, each line contains a set of layers
-        whose output features are stitched together by the cat
-        operation.
-
-        output example:
-        Dependency Set, Layers
-        set1, Conv1, Conv2
-        set2, Conv3, Conv4
-        """
-        header = ['Dependency Set', 'Layers']
-        setid = 0
-        with open(filepath, 'w') as csvf:
-            csv_w = csv.writer(csvf, delimiter=',')
-            csv_w.writerow(header)
-            for layers in self.dependency_sets:
-                setid += 1
-                row = ['Set %d' % setid]
-                row.extend(list(layers))
-                csv_w.writerow(row)
-
-
 class GroupDependency(Dependency):
     def __init__(self, model=None, dummy_input=None, traced_model=None):
         """
@@ -372,6 +330,7 @@ def __init__(self, model=None, dummy_input=None, traced_model=None):
         if we alreay has the traced graph of the target model, we donnot
         need to trace the model again.
         """
+        self.min_groups = {}
         super(GroupDependency, self).__init__(model, dummy_input, traced_model)

     def _get_parent_convs(self, node):
@@ -451,27 +410,33 @@ def build_dependency(self):
             key: the name of conv layers, value: the minimum
             value that the number of filters should be divisible to.
         """
+        self.groups = {}
         for node in self.graph.nodes_py.nodes_op:
             if node.op_type == 'Conv2d' or node.op_type == 'ConvTranspose2d':
                 group = self._get_conv_groups(node)
-
-                if node.name in self.dependency:
+                if node.name in self.groups:
                     # the conv layer whose group is larger than 1 will require that
                     # it's number of output channel to be divisible by the number of group.
-                    self.dependency[node.name] = max(
-                        self.dependency[node.name], group)
+                    self.groups[node.name].append(group)
                 else:
-                    self.dependency[node.name] = group
+                    self.groups[node.name] = [group]
                 if group > 1:
                     # for the conv layer whose group is larger than 1, it will require the number
                     # of output channels of their parent conv layer to be divisible by group.
                     parent_convs = self._get_parent_convs(node)
                     for parent in parent_convs:
-                        if parent in self.dependency:
-                            self.dependency[parent] = max(
-                                self.dependency[parent], group)
+                        if parent in self.groups:
+                            self.groups[parent].append(group)
                         else:
-                            self.dependency[parent] = group
+                            self.groups[parent] = [group]
+
+        for name in self.groups:
+            # the filter count must be divisible by the lcm of all collected
+            # group counts; min_groups records the finest group granularity
+            # that is consistent with every collected count
+            self.dependency[name] = lcm_list(self.groups[name])
+            if min(self.groups[name]) == gcd_list(self.groups[name]):
+                self.min_groups[name] = min(self.groups[name])
+            else:
+                self.min_groups[name] = 1
+
+        return self.dependency

     def export(self, filepath):
@@ -501,3 +466,110 @@ def export(self, filepath):
     @property
     def dependency_sets(self):
         return self.dependency
+
+
+class ReshapeDependency(Dependency):
+    def __init__(self, model=None, dummy_input=None, traced_model=None):
+        """
+        Some models contain view/reshape functions whose parameters are
+        hard-coded and cannot be replaced at all, so these functions may
+        impose constraints on the shapes of their inputs. In this class, we
+        find the conv/linear layers that directly feed such reshape
+        functions. If you get a shape conflict when running forward
+        inference on the sped-up model, try removing these layers from the
+        pruner config list and try again.
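+
+        Example (illustrative): if a model computes ``x.view(x.size(0), 512)``,
+        the conv/linear layers producing ``x`` are recorded here, because
+        pruning their output channels would change the flattened size and
+        conflict with the hard-coded 512.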
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            The model to be analyzed.
+        dummy_input : torch.Tensor
+            The example input data to trace the network architecture.
+        traced_model : torch._C.Graph
+            if we already have the traced graph of the target model, we do
+            not need to trace the model again.
+        """
+        super(ReshapeDependency, self).__init__(
+            model, dummy_input, traced_model)
+
+    def _get_parent_layers(self, node):
+        """
+        Find the nearest ancestor conv/linear layers for the target node.
+
+        Parameters
+        ----------
+        node : torch._C.Node
+            target node.
+
+        Returns
+        -------
+        parent_layers: list
+            nearest ancestor conv/linear layers for the target node.
+        """
+        parent_layers = []
+        queue = []
+        queue.append(node)
+        while queue:
+            curnode = queue.pop(0)
+            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
+                # find the first met conv/linear layer
+                parent_layers.append(curnode.name)
+                continue
+            parents = self.graph.find_predecessors(curnode.unique_name)
+            parents = [self.graph.name_to_node[name] for name in parents]
+            for parent in parents:
+                queue.append(parent)
+        return parent_layers
+
+    def build_dependency(self):
+        """
+        Build the reshape dependencies for the conv/linear layers
+        in the model.
+        """
+        # unpack the tuple/list manually before analyzing the
+        # reshape dependency
+        self.graph.unpack_manually()
+        for node in self.graph.nodes_py.nodes_op:
+            parent_layers = []
+            # find the nodes that contain aten::view
+            # or aten::reshape operations
+            if node.op_type in ['aten::view', 'aten::reshape']:
+                logger.info('Detected reshape-like function: %s', node.op_type)
+                parent_layers = self._get_parent_layers(node)
+                logger.info('Parent layers: %s', str(parent_layers))
+                self.dependency[node.unique_name] = parent_layers
+
+    def export(self, filepath):
+        """
+        export the reshape dependencies as a csv file.
+
+        Output example:
+        Reshape OP, Dependent Layers
+        model.view.1,layer1.1.conv2,layer1.0.conv2,conv1
+        model.mean.1,layer1.0.conv1
+        model.reshape.1,layer1.1.conv1
+        """
+        header = ['Reshape OP', 'Dependent Layers']
+        with open(filepath, 'w') as csvf:
+            csv_w = csv.writer(csvf, delimiter=',')
+            csv_w.writerow(header)
+            for reshape_op in self.dependency:
+                # note: list.extend returns None, so build the row with
+                # list concatenation instead
+                row = [reshape_op] + list(self.dependency[reshape_op])
+                csv_w.writerow(row)
+
+    @property
+    def dependency_sets(self):
+        """
+        Get the layers that directly feed reshape-like operations.
+
+        Returns
+        -------
+        dependency_sets : list
+            a flat list of the layer names on which some reshape
+            operation depends, for example ['conv1', 'conv2'].
+        """
+        d_sets = []
+        for reshape_node in self.dependency:
+            d_sets.extend(self.dependency[reshape_node])
+        d_sets = list(set(d_sets))
+        return d_sets
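+
+
+# A minimal usage sketch (illustrative; the model and input shape are
+# assumptions, not part of this module):
+#
+#   import torch
+#   from torchvision.models import resnet18
+#   depen = ReshapeDependency(resnet18(), torch.randn(1, 3, 224, 224))
+#   print(depen.dependency_sets)   # layers that feed view/reshape ops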
diff --git a/nni/compression/pytorch/utils/utils.py b/nni/compression/pytorch/utils/utils.py
index c687c5e2a6..6def03bc91 100644
--- a/nni/compression/pytorch/utils/utils.py
+++ b/nni/compression/pytorch/utils/utils.py
@@ -1,5 +1,10 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+import torch
+from .shape_dependency import ReshapeDependency
+
+torch_float_dtype = [torch.float, torch.float16,
+                     torch.float32, torch.float64, torch.half, torch.double]
+torch_integer_dtype = [torch.uint8, torch.int8, torch.int16,
+                       torch.short, torch.int32, torch.long, torch.bool]

 def get_module_by_name(model, module_name):
     """
@@ -28,3 +33,50 @@ def get_module_by_name(model, module_name):
         return model, leaf_module
     else:
         return None, None
+
+
+def rand_like_with_shape(shape, ori_t):
+    """
+    Return a new random tensor with the given shape, matching the dtype,
+    device and requires_grad attributes of the original tensor.
+    """
+    assert isinstance(ori_t, torch.Tensor)
+    device = ori_t.device
+    dtype = ori_t.dtype
+    require_grad = ori_t.requires_grad
+    lower_bound = torch.min(ori_t)
+    higher_bound = torch.max(ori_t)
+    if dtype in torch_integer_dtype:
+        # torch.randint expects python ints as its bounds
+        return torch.randint(int(lower_bound), int(higher_bound) + 1, shape, dtype=dtype, device=device)
+    else:
+        return torch.rand(shape, dtype=dtype, device=device, requires_grad=require_grad)
+
+
+def randomize_tensor(tensor, start=1, end=100):
+    """
+    Randomize the target tensor in place according to the given range.
+    """
+    assert isinstance(tensor, torch.Tensor)
+    if tensor.dtype in torch_integer_dtype:
+        # integer tensors could only be randomized by torch.randint;
+        # for the time being we leave them unchanged
+        # torch.randint(int(start), int(end), tensor.size(), out=tensor.data, dtype=tensor.dtype)
+        pass
+    else:
+        # we can use nn.init.uniform_ to randomize this tensor
+        # Note: a tensor with an integer dtype cannot be randomized
+        # with nn.init.uniform_
+        torch.nn.init.uniform_(tensor.data, start, end)
+
+
+def not_safe_to_prune(model, dummy_input):
+    """
+    Get the layers that are *not* safe to prune (pruning them may bring
+    shape conflicts), i.e. the layers that directly feed fixed-parameter
+    view/reshape operations.
+
+    Parameters
+    ----------
+    model: torch.nn.Module
+        The target model to prune.
+    dummy_input: torch.Tensor/list of torch.Tensor/tuple of Tensor
+        The dummy input used to trace the model.
+
+    Returns
+    -------
+    list
+        The names of the layers that are not safe to prune.
+    """
+    reshape_dset = ReshapeDependency(model, dummy_input)
+    return reshape_dset.dependency_sets
\ No newline at end of file
diff --git a/test/ut/sdk/test_compression_utils.py b/test/ut/sdk/test_compression_utils.py
index 5423f762b0..2f1ab4c70c 100644
--- a/test/ut/sdk/test_compression_utils.py
+++ b/test/ut/sdk/test_compression_utils.py
@@ -116,7 +116,7 @@ def test_mask_conflict(self):
         pruner.export_model(ck_file, mask_file)
         pruner._unwrap_model()
         # Fix the mask conflict
-        fixed_mask, _ = fix_mask_conflict(mask_file, net, dummy_input)
+        fixed_mask = fix_mask_conflict(mask_file, net, dummy_input)

         # use the channel dependency groud truth to check if
         # fix the mask conflict successfully
diff --git a/test/ut/sdk/test_model_speedup.py b/test/ut/sdk/test_model_speedup.py
index 9ce7a7cba9..d564c3274a 100644
--- a/test/ut/sdk/test_model_speedup.py
+++ b/test/ut/sdk/test_model_speedup.py
@@ -1,7 +1,9 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
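+# These tests exercise ModelSpeedup end to end: prune a model with a pruner,
+# export the masks, rebuild a compact model from the masks, and compare the
+# outputs of the original and the sped-up model.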
+import logging import os +import gc import psutil import sys import numpy as np @@ -9,18 +11,20 @@ import torchvision.models as models import torch.nn as nn import torch.nn.functional as F -from torchvision.models.vgg import vgg16 +from torchvision.models.vgg import vgg16, vgg11 from torchvision.models.resnet import resnet18 +from torchvision.models.mobilenet import mobilenet_v2 import unittest from unittest import TestCase, main from nni.compression.pytorch import ModelSpeedup, apply_compression_results -from nni.algorithms.compression.pytorch.pruning import L1FilterPruner +from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, LevelPruner from nni.algorithms.compression.pytorch.pruning.weight_masker import WeightMasker from nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner import DependencyAwarePruner torch.manual_seed(0) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + BATCH_SIZE = 2 # the relative distance RELATIVE_THRESHOLD = 0.01 @@ -105,6 +109,55 @@ def forward(self, x): return x +class TupleUnpack_backbone(nn.Module): + def __init__(self, width): + super(TupleUnpack_backbone, self).__init__() + self.model_backbone = mobilenet_v2( + pretrained=False, width_mult=width, num_classes=3) + + def forward(self, x): + x1 = self.model_backbone.features[:7](x) + x2 = self.model_backbone.features[7:14](x1) + x3 = self.model_backbone.features[14:18](x2) + return [x1, x2, x3] + + +class TupleUnpack_FPN(nn.Module): + def __init__(self): + super(TupleUnpack_FPN, self).__init__() + + self.conv1 = nn.Conv2d(32, 48, kernel_size=( + 1, 1), stride=(1, 1), bias=False) + self.conv2 = nn.Conv2d(96, 48, kernel_size=( + 1, 1), stride=(1, 1), bias=False) + self.conv3 = nn.Conv2d(320, 48, kernel_size=( + 1, 1), stride=(1, 1), bias=False) + + # self.init_weights() + + def forward(self, inputs): + """Forward function.""" + laterals = [] + + laterals.append(self.conv1(inputs[0])) # inputs[0]==x1 + laterals.append(self.conv2(inputs[1])) # inputs[1]==x2 + laterals.append(self.conv3(inputs[2])) # inputs[2]==x3 + + return laterals + + +class TupleUnpack_Model(nn.Module): + def __init__(self): + super(TupleUnpack_Model, self).__init__() + self.backbone = TupleUnpack_backbone(1.0) + self.fpn = TupleUnpack_FPN() + + def forward(self, x): + x1 = self.backbone(x) + out = self.fpn(x1) + return out + + dummy_input = torch.randn(2, 1, 28, 28) SPARSITY = 0.5 MODEL_FILE, MASK_FILE = './11_model.pth', './l1_mask.pth' @@ -129,6 +182,7 @@ def generate_random_sparsity(model): 'sparsity': sparsity}) return cfg_list + def generate_random_sparsity_v2(model): """ Only select 50% layers to prune. 
@@ -139,9 +193,10 @@ def generate_random_sparsity_v2(model):
         if np.random.uniform(0, 1.0) > 0.5:
             sparsity = np.random.uniform(0.5, 0.99)
             cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name],
-                            'sparsity': sparsity})
+                             'sparsity': sparsity})
     return cfg_list

+
 def zero_bn_bias(model):
     with torch.no_grad():
         for name, module in model.named_modules():
@@ -231,19 +286,6 @@ def channel_prune(model):


 class SpeedupTestCase(TestCase):
-    def test_speedup_vgg16(self):
-        prune_model_l1(vgg16())
-        model = vgg16()
-        model.train()
-        ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), MASK_FILE)
-        ms.speedup_model()
-
-        orig_model = vgg16()
-        assert model.training
-        assert model.features[2].out_channels == int(
-            orig_model.features[2].out_channels * SPARSITY)
-        assert model.classifier[0].in_features == int(
-            orig_model.classifier[0].in_features * SPARSITY)

     def test_speedup_bigmodel(self):
         prune_model_l1(BigModel())
@@ -253,7 +295,7 @@ def test_speedup_bigmodel(self):
         mask_out = model(dummy_input)

         model.train()
-        ms = ModelSpeedup(model, dummy_input, MASK_FILE)
+        ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=2)
         ms.speedup_model()

         assert model.training
@@ -289,7 +331,7 @@ def test_convtranspose_model(self):
         new_model = TransposeModel()
         state_dict = torch.load(MODEL_FILE)
         new_model.load_state_dict(state_dict)
-        ms = ModelSpeedup(new_model, dummy_input, MASK_FILE)
+        ms = ModelSpeedup(new_model, dummy_input, MASK_FILE, confidence=2)
         ms.speedup_model()
         zero_bn_bias(ori_model)
         zero_bn_bias(new_model)
@@ -297,26 +339,34 @@ def test_convtranspose_model(self):
         new_out = new_model(dummy_input)
         ori_sum = torch.sum(ori_out)
         speeded_sum = torch.sum(new_out)
-        print('Tanspose Speedup Test: ori_sum={} speedup_sum={}'.format(ori_sum, speeded_sum))
+        print('Transpose Speedup Test: ori_sum={} speedup_sum={}'.format(
+            ori_sum, speeded_sum))
         assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
-            (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
-
-    # FIXME: This test case might fail randomly, no idea why
-    # Example: https://msrasrg.visualstudio.com/NNIOpenSource/_build/results?buildId=16282
+               (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)

-    def test_speedup_integration(self):
-        # skip this test on windows(7GB mem available) due to memory limit
-        # Note: hack trick, may be updated in the future
-        if 'win' in sys.platform or 'Win'in sys.platform:
-            print('Skip test_speedup_integration on windows due to memory limit!')
+    def test_speedup_integration_small(self):
+        model_list = ['resnet18', 'mobilenet_v2', 'alexnet']
+        self.speedup_integration(model_list)
+
+    def test_speedup_integration_big(self):
+        model_list = ['vgg11', 'vgg16', 'resnet34', 'squeezenet1_1',
+                      'densenet121', 'resnet50', 'wide_resnet50_2']
+        mem_info = psutil.virtual_memory()
+        ava_gb = mem_info.available / 1024.0 / 1024 / 1024
+        print('Available memory size: %.2f GB' % ava_gb)
+        if ava_gb < 8.0:
+            # the available memory is so small that we may run into an OOM
+            # exception; skip this test in the pipeline due to the memory
+            # limitation
             return
+        self.speedup_integration(model_list)

+    def speedup_integration(self, model_list, speedup_cfg=None):
         Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2]

-        for model_name in ['resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121' , 'densenet169',
-                           # 'inception_v3' inception is too large and may fail the pipeline
-                           'resnet50']:
-
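+        # NOTE: the confidence argument passed to ModelSpeedup below is, to
+        # our understanding, the batch size of the internally generated dummy
+        # input used for mask inference; a small value keeps the memory
+        # footprint of the test pipeline low.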
+        for model_name in model_list:
             for gen_cfg_func in Gen_cfg_funcs:
                 kwargs = {
                     'pretrained': True
@@ -334,7 +384,10 @@ def test_speedup_integration(self):
                 speedup_model.eval()
                 # random generate the prune config for the pruner
                 cfgs = gen_cfg_func(net)
-                print("Testing {} with compression config \n {}".format(model_name, cfgs))
+                print("Testing {} with compression config \n {}".format(
+                    model_name, cfgs))
+                if len(cfgs) == 0:
+                    continue
                 pruner = L1FilterPruner(net, cfgs)
                 pruner.compress()
                 pruner.export_model(MODEL_FILE, MASK_FILE)
@@ -345,7 +398,10 @@ def test_speedup_integration(self):
                 zero_bn_bias(speedup_model)

                 data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
-                ms = ModelSpeedup(speedup_model, data, MASK_FILE)
+                if speedup_cfg is None:
+                    speedup_cfg = {}
+                ms = ModelSpeedup(speedup_model, data,
+                                  MASK_FILE, confidence=2, **speedup_cfg)
                 ms.speedup_model()

                 speedup_model.eval()
@@ -355,12 +411,13 @@ def test_speedup_integration(self):
                 ori_sum = torch.sum(ori_out).item()
                 speeded_sum = torch.sum(speeded_out).item()
                 print('Sum of the output of %s (before speedup):' %
-                    model_name, ori_sum)
-                print('Sum of the output of %s (after speedup):' %
-                    model_name, speeded_sum)
+                      model_name, ori_sum)
+                print('Sum of the output of %s (after speedup):' %
+                      model_name, speeded_sum)
                 assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
                     (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
-
+                print("Collecting Garbage")
+                gc.collect(2)

     def test_channel_prune(self):
         orig_net = resnet18(num_classes=10).to(device)
@@ -378,7 +435,7 @@ def test_channel_prune(self):
         net.eval()

         data = torch.randn(BATCH_SIZE, 3, 128, 128).to(device)
-        ms = ModelSpeedup(net, data, MASK_FILE)
+        ms = ModelSpeedup(net, data, MASK_FILE, confidence=2)
         ms.speedup_model()
         ms.bound_model(data)

@@ -391,11 +448,56 @@ def test_channel_prune(self):
         assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
             (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)

+    def test_speedup_tupleunpack(self):
+        """This test is reported in issue 3645."""
+        model = TupleUnpack_Model()
+        cfg_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}]
+        dummy_input = torch.rand(2, 3, 224, 224)
+        pruner = L1FilterPruner(model, cfg_list)
+        pruner.compress()
+        model(dummy_input)
+        pruner.export_model(MODEL_FILE, MASK_FILE)
+        ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=2)
+        ms.speedup_model()
+
+    def test_finegrained_speedup(self):
+        """Test the speedup on fine-grained sparsity."""
+        class MLP(nn.Module):
+            def __init__(self):
+                super(MLP, self).__init__()
+                self.fc1 = nn.Linear(1024, 1024)
+                self.fc2 = nn.Linear(1024, 1024)
+                self.fc3 = nn.Linear(1024, 512)
+                self.fc4 = nn.Linear(512, 10)
+
+            def forward(self, x):
+                x = x.view(-1, 1024)
+                x = self.fc1(x)
+                x = self.fc2(x)
+                x = self.fc3(x)
+                x = self.fc4(x)
+                return x
+        model = MLP().to(device)
+        dummy_input = torch.rand(16, 1, 32, 32).to(device)
+        cfg_list = [{'op_types': ['Linear'], 'sparsity': 0.99}]
+        pruner = LevelPruner(model, cfg_list)
+        pruner.compress()
+        print('Original Arch')
+        print(model)
+        pruner.export_model(MODEL_FILE, MASK_FILE)
+        pruner._unwrap_model()
+        ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=4)
+        ms.speedup_model()
+        print('Model after fine-grained speedup')
+        print(model)
+
     def tearDown(self):
         if os.path.exists(MODEL_FILE):
             os.remove(MODEL_FILE)
         if os.path.exists(MASK_FILE):
             os.remove(MASK_FILE)
+        # GC to release memory
+        gc.collect(2)


 if __name__ == '__main__':