Skip to content

Commit

Permalink
fix the bug in line no.336 in mask_conflict.py (microsoft#3629)
Browse files Browse the repository at this point in the history
Co-authored-by: Xuesong Wang <wangxuesong@dm-ai.cn>
  • Loading branch information
Davidxswang and Xuesong Wang authored May 19, 2021
1 parent 797b963 commit 03ff374
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion nni/compression/pytorch/utils/mask_conflict.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def fix_mask(self):
elif type(m).__name__ == 'Linear':
new_mask[:, merged_index] = 1.
elif type(m).__name__ == 'BatchNorm2d':
new_mask = merged_index.type_as(orig_mask)
new_mask = merged_channel_mask.type_as(orig_mask)
else:
raise RuntimeError(
f'unsupported module type: {type(m).__name__}')
Expand Down

0 comments on commit 03ff374

Please sign in to comment.