Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
[bug bash] issue 2706 (#2818)
Browse files Browse the repository at this point in the history
* bug bash

* fix one more bug.

Signed-off-by: Ningxin <Ningxin.Zheng@microsoft.com>
  • Loading branch information
zheng-ningxin authored Aug 26, 2020
1 parent 625a72d commit beeea32
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def cat_inshape(module_masks, mask, cat_info, last_visited):
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the batchnorm2d
The ModuleMasks instance of the Conv2d
mask : CoarseMask
The mask of its input tensor
cat_info: dict
Expand Down
5 changes: 4 additions & 1 deletion src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,14 @@ def fix_mask(self):
continue
# pad the mask for the non-pruned layers
for layer in layers:
if layer in self.masks:
continue
module = name_to_module[layer]
w_shape = module.weight.data.size()
w_mask = torch.ones(w_shape).to(device)
b_mask = None
if hasattr(module, 'bias'):
if hasattr(module, 'bias') and module.bias is not None:
# module.bias may be None
b_shape = module.bias.data.size()
b_mask = torch.ones(b_shape).to(device)
self.masks[layer] = {'weight':w_mask, 'bias':b_mask}
Expand Down
2 changes: 1 addition & 1 deletion src/sdk/pynni/tests/test_model_speedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_speedup_bigmodel(self):
assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)

def test_speedup_integration(self):
for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'inception_v3']:
for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'densenet169', 'inception_v3']:
Model = getattr(models, model_name)
net = Model(pretrained=True, progress=False).to(device)
speedup_model = Model().to(device)
Expand Down

0 comments on commit beeea32

Please sign in to comment.