Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
fix the bug in nni/compression/pytorch/speedup/infer_shape.py, line n…
Browse files Browse the repository at this point in the history
…o. 625-626 (#3588)
  • Loading branch information
Davidxswang authored Apr 28, 2021
1 parent 3725c8f commit f102d5b
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions nni/compression/pytorch/speedup/infer_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,15 +597,15 @@ def view_outshape(module_masks, mask, shape):
Parameters
----------
module_masks : ModuleMasks
The ModuleMasks instance of the ```flatten``` op
The ModuleMasks instance of the ```view``` op
mask : CoarseMask
The mask of its input tensor
The mask of its output tensor
shape : dict
Original shape of its input and output tensors
Returns
-------
CoarseMask
The mask of its output tensor
The mask of its input tensor
"""
# NOTE: the case constrained by the following four asserts
assert shape['in_shape'][0] == shape['out_shape'][0]
Expand All @@ -620,10 +620,11 @@ def view_outshape(module_masks, mask, shape):

module_masks.set_output_mask(mask)
input_cmask = CoarseMask(num_dim=4)
index = []
index = set()
step_size = shape['in_shape'][2] * shape['in_shape'][3]
for loc in mask.mask_index[1]:
index.extend([loc * step_size + i for i in range(step_size)])
index.add(loc // step_size)
index = sorted(list(index))
input_cmask.add_index_mask(dim=1, index=torch.tensor(index).to(mask.mask_index[1].device)) # pylint: disable=not-callable
module_masks.set_input_mask(input_cmask)

Expand Down

0 comments on commit f102d5b

Please sign in to comment.