Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[bug bash] issue 2706 #2818

Merged
merged 2 commits into from
Aug 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def cat_inshape(module_masks, mask, cat_info, last_visited):
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the batchnorm2d
The ModuleMasks instance of the Conv2d
mask : CoarseMask
The mask of its input tensor
cat_info: dict
Expand Down
5 changes: 4 additions & 1 deletion src/sdk/pynni/nni/compression/torch/utils/mask_conflict.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,14 @@ def fix_mask(self):
continue
# pad the mask for the non-pruned layers
for layer in layers:
if layer in self.masks:
continue
module = name_to_module[layer]
w_shape = module.weight.data.size()
w_mask = torch.ones(w_shape).to(device)
b_mask = None
if hasattr(module, 'bias'):
if hasattr(module, 'bias') and module.bias is not None:
# module.bias may be None
b_shape = module.bias.data.size()
b_mask = torch.ones(b_shape).to(device)
self.masks[layer] = {'weight':w_mask, 'bias':b_mask}
Expand Down
2 changes: 1 addition & 1 deletion src/sdk/pynni/tests/test_model_speedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_speedup_bigmodel(self):
assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)

def test_speedup_integration(self):
for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'inception_v3']:
for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'densenet169', 'inception_v3']:
Model = getattr(models, model_name)
net = Model(pretrained=True, progress=False).to(device)
speedup_model = Model().to(device)
Expand Down