From f1b8cd2cf627edefeeb0f316a6b9e47236557530 Mon Sep 17 00:00:00 2001 From: Ningxin Zheng <49771382+zheng-ningxin@users.noreply.github.com> Date: Thu, 24 Sep 2020 16:08:39 +0800 Subject: [PATCH] Add the support for aten::mul operator. (#2905) --- .../pynni/nni/compression/torch/speedup/compress_modules.py | 1 + src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py | 6 ++++++ src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py | 2 +- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py b/src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py index 0b349f9d5c..37d6a8e1e1 100644 --- a/src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py +++ b/src/sdk/pynni/nni/compression/torch/speedup/compress_modules.py @@ -15,6 +15,7 @@ 'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask), 'ReLU': lambda module, mask: no_replace(module, mask), 'ReLU6': lambda module, mask: no_replace(module, mask), + 'Sigmoid': lambda module, mask: no_replace(module, mask), 'Linear': lambda module, mask: replace_linear(module, mask), 'Dropout': lambda module, mask: no_replace(module, mask), 'Dropout2d': lambda module, mask: no_replace(module, mask), diff --git a/src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py b/src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py index 99c9189d9c..368518f3cf 100644 --- a/src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py +++ b/src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py @@ -221,12 +221,14 @@ def __repr__(self): infer_from_inshape = { 'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask), 'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask), + 'Sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask), 'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask), 'aten::tanh': lambda module_masks, mask: relu_inshape(module_masks, mask), 'aten::tanh_': 
lambda module_masks, mask: relu_inshape(module_masks, mask), 'aten::hardtanh': lambda module_masks, mask: relu_inshape(module_masks, mask), 'aten::hardtanh_': lambda module_masks, mask: relu_inshape(module_masks, mask), 'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask), + 'aten::sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask), 'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask), 'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), 'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), @@ -243,6 +245,10 @@ def __repr__(self): 'BatchNorm2d': lambda module_masks, mask: batchnorm2d_inshape(module_masks, mask), 'aten::add_': lambda module_masks, mask: add_inshape(module_masks, mask), 'aten::add': lambda module_mask, mask: add_inshape(module_mask, mask), + # mul has a similar behaviour to add; they both require + # the input tensors to have the same shape + 'aten::mul': lambda module_mask, mask: add_inshape(module_mask, mask), + 'aten::mul_': lambda module_mask, mask: add_inshape(module_mask, mask), 'aten::cat': lambda module_mask, mask, cat_info, last_visited: cat_inshape(module_mask, mask, cat_info, last_visited), 'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape), 'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask), diff --git a/src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py b/src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py index ffbcfce3ad..3945a961df 100644 --- a/src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py +++ b/src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py @@ -284,7 +284,7 @@ def fix_mask(self): ori_channels = w_shape[0] for i in channel_remain: mask['weight'][i] = torch.ones(w_shape[1:]) - if hasattr(mask, 'bias'): + if 'bias' in mask and mask['bias'] is not None: mask['bias'][i] = 1 _logger.info(','.join(dset)) 
_logger.info('Pruned Filters after fixing conflict:')