Add support for the aten::mul operator. (#2905)
zheng-ningxin authored Sep 24, 2020 · commit f1b8cd2 (1 parent: 986d58c)
Showing 3 changed files with 8 additions and 1 deletion.
@@ -15,6 +15,7 @@
     'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
     'ReLU': lambda module, mask: no_replace(module, mask),
     'ReLU6': lambda module, mask: no_replace(module, mask),
+    'Sigmoid': lambda module, mask: no_replace(module, mask),
     'Linear': lambda module, mask: replace_linear(module, mask),
     'Dropout': lambda module, mask: no_replace(module, mask),
     'Dropout2d': lambda module, mask: no_replace(module, mask),
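Elementwise activations such as Sigmoid carry no learnable parameters, so the replace table can map them to a handler that leaves the module untouched. A minimal sketch of what such a no_replace handler plausibly looks like (illustrative only; the real NNI implementation may differ):

import torch.nn as nn

def no_replace(module: nn.Module, mask) -> nn.Module:
    # Parameter-free layers (Sigmoid, ReLU, Dropout, ...) need no
    # rebuilding when surrounding layers are pruned; return as-is.
    return module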
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py (6 additions, 0 deletions)
@@ -221,12 +221,14 @@ def __repr__(self):
 infer_from_inshape = {
     'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask),
+    'Sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'aten::tanh': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'aten::tanh_': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'aten::hardtanh': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'aten::hardtanh_': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask),
+    'aten::sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
     'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
     'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
@@ -243,6 +245,10 @@ def __repr__(self):
     'BatchNorm2d': lambda module_masks, mask: batchnorm2d_inshape(module_masks, mask),
     'aten::add_': lambda module_masks, mask: add_inshape(module_masks, mask),
     'aten::add': lambda module_mask, mask: add_inshape(module_mask, mask),
+    # mul behaves the same way as add: they both require
+    # the input tensors to have the same shape
+    'aten::mul': lambda module_mask, mask: add_inshape(module_mask, mask),
+    'aten::mul_': lambda module_mask, mask: add_inshape(module_mask, mask),
     'aten::cat': lambda module_mask, mask, cat_info, last_visited: cat_inshape(module_mask, mask, cat_info, last_visited),
     'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape),
     'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask),
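Reusing add_inshape for aten::mul and aten::mul_ rests on a simple property: any elementwise binary op whose inputs share a shape propagates a channel mask unchanged from inputs to output. A quick self-contained check of that property in plain PyTorch (illustrative; not NNI API code):

import torch

x = torch.randn(1, 4, 8, 8)
y = torch.randn(1, 4, 8, 8)
# 0/1 channel mask, broadcast over the spatial dimensions
mask = torch.tensor([1., 0., 1., 0.]).view(1, 4, 1, 1)

# Masking the inputs of an elementwise mul is equivalent to masking
# its output (mask * mask == mask for a 0/1 mask), so mul can share
# add's shape-inference handler.
assert torch.allclose((x * mask) * (y * mask), (x * y) * mask)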
src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py (1 addition, 1 deletion)
@@ -284,7 +284,7 @@ def fix_mask(self):
 ori_channels = w_shape[0]
 for i in channel_remain:
     mask['weight'][i] = torch.ones(w_shape[1:])
-    if hasattr(mask, 'bias'):
+    if 'bias' in mask and mask['bias'] is not None:
         mask['bias'][i] = 1
 _logger.info(','.join(dset))
 _logger.info('Pruned Filters after fixing conflict:')
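The mask_conflict change fixes a real bug rather than a style issue: the masks here are plain Python dicts, and hasattr tests attributes, not keys, so hasattr(mask, 'bias') was always False and the bias mask was silently never updated. A minimal repro of the difference (illustrative, not NNI code):

import torch

mask = {'weight': torch.ones(4, 3), 'bias': torch.zeros(4)}

print(hasattr(mask, 'bias'))                        # False: dict keys are not attributes
print('bias' in mask and mask['bias'] is not None)  # True: the corrected membership test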
