From d817c3de0e86b8ff64d1fdb2f6e57f407ba93f0b Mon Sep 17 00:00:00 2001 From: Xuesong Wang Date: Tue, 27 Apr 2021 11:23:51 +0800 Subject: [PATCH 1/2] fix the bug in nni/compression/pytorch/speedup/infer_shape.py on line no. 625, in the function of view_outshape --- nni/compression/pytorch/speedup/infer_shape.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/nni/compression/pytorch/speedup/infer_shape.py b/nni/compression/pytorch/speedup/infer_shape.py index 611e223c3c..0cba5f2463 100644 --- a/nni/compression/pytorch/speedup/infer_shape.py +++ b/nni/compression/pytorch/speedup/infer_shape.py @@ -597,15 +597,15 @@ def view_outshape(module_masks, mask, shape): Parameters ---------- module_masks : ModuleMasks - The ModuleMasks instance of the ```flatten``` op + The ModuleMasks instance of the ```view``` op mask : CoarseMask - The mask of its input tensor + The mask of its output tensor shape : dict Original shape of its input and output tensors Returns ------- CoarseMask - The mask of its output tensor + The mask of its input tensor """ # NOTE: the case constrained by the following four asserts assert shape['in_shape'][0] == shape['out_shape'][0] @@ -620,10 +620,11 @@ def view_outshape(module_masks, mask, shape): module_masks.set_output_mask(mask) input_cmask = CoarseMask(num_dim=4) - index = [] + index = set() step_size = shape['in_shape'][2] * shape['in_shape'][3] for loc in mask.mask_index[1]: - index.extend([loc * step_size + i for i in range(step_size)]) + index.add(loc // step_size) + index = list(index) input_cmask.add_index_mask(dim=1, index=torch.tensor(index).to(mask.mask_index[1].device)) # pylint: disable=not-callable module_masks.set_input_mask(input_cmask) From 513e4fb9c3976253df77e470a13996702c43a4d5 Mon Sep 17 00:00:00 2001 From: Xuesong Wang Date: Tue, 27 Apr 2021 14:07:57 +0800 Subject: [PATCH 2/2] fix the ut by sort the index list --- nni/compression/pytorch/speedup/infer_shape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nni/compression/pytorch/speedup/infer_shape.py b/nni/compression/pytorch/speedup/infer_shape.py index 0cba5f2463..7cdaf1c6a4 100644 --- a/nni/compression/pytorch/speedup/infer_shape.py +++ b/nni/compression/pytorch/speedup/infer_shape.py @@ -624,7 +624,7 @@ def view_outshape(module_masks, mask, shape): step_size = shape['in_shape'][2] * shape['in_shape'][3] for loc in mask.mask_index[1]: index.add(loc // step_size) - index = list(index) + index = sorted(list(index)) input_cmask.add_index_mask(dim=1, index=torch.tensor(index).to(mask.mask_index[1].device)) # pylint: disable=not-callable module_masks.set_input_mask(input_cmask)