
Commit e5a208b
fix speedup with CUDA (#2947)
chicm-ms authored Oct 12, 2020
1 parent 1cd7ad5 commit e5a208b
Showing 1 changed file with 3 additions and 3 deletions.
src/sdk/pynni/nni/compression/torch/speedup/infer_shape.py (6 changes: 3 additions & 3 deletions)
@@ -573,7 +573,7 @@ def view_inshape(module_masks, mask, shape):
     step_size = shape['in_shape'][2] * shape['in_shape'][3]
     for loc in mask.mask_index[1]:
         index.extend([loc * step_size + i for i in range(step_size)])
-    output_cmask.add_index_mask(dim=1, index=torch.tensor(index)) # pylint: disable=not-callable
+    output_cmask.add_index_mask(dim=1, index=torch.tensor(index).to(mask.mask_index[1].device)) # pylint: disable=not-callable
     module_masks.set_output_mask(output_cmask)
     return output_cmask

@@ -609,7 +609,7 @@ def view_outshape(module_masks, mask, shape):
     step_size = shape['in_shape'][2] * shape['in_shape'][3]
     for loc in mask.mask_index[1]:
         index.extend([loc * step_size + i for i in range(step_size)])
-    input_cmask.add_index_mask(dim=1, index=torch.tensor(index)) # pylint: disable=not-callable
+    input_cmask.add_index_mask(dim=1, index=torch.tensor(index).to(mask.mask_index[1].device)) # pylint: disable=not-callable
     module_masks.set_input_mask(input_cmask)

     return input_cmask
@@ -870,7 +870,7 @@ def convert_to_coarse_mask(mask, dim=0):
     if index is None:
         return None, None, None
     else:
-        index = torch.LongTensor(index).to(weight_mask.device)
+        index = index.long().to(weight_mask.device)
         weight_cmask = CoarseMask(num_dim=4)
         weight_cmask.add_index_mask(dim=dim, index=index)
         bias_cmask = None
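Context for this change: index here is already a tensor derived from the weight mask, so the old torch.LongTensor(index) call rebuilt it with the legacy CPU constructor, which at best makes an extra CPU copy and can fail on a CUDA input; index.long().to(weight_mask.device) instead casts the existing tensor and keeps it on the mask's device. A minimal sketch of the pattern, with made-up shapes and not code from the repository:

import torch

# Hypothetical stand-in for a pruned conv weight mask; moved to the GPU when available.
weight_mask = torch.ones(8, 3, 3, 3)
if torch.cuda.is_available():
    weight_mask = weight_mask.cuda()

# An index derived from the mask, e.g. output channels whose mask is not all zero.
index = torch.nonzero(weight_mask.sum(dim=(1, 2, 3))).squeeze(1)

# New pattern from the commit: cast the existing tensor and keep it on the mask's device,
# instead of rebuilding it with the CPU-only torch.LongTensor constructor.
index = index.long().to(weight_mask.device)
print(index.device, index.dtype)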