diff --git a/nni/algorithms/compression/pytorch/pruning/structured_pruning.py b/nni/algorithms/compression/pytorch/pruning/structured_pruning.py index 59396ded96..df65cd372d 100644 --- a/nni/algorithms/compression/pytorch/pruning/structured_pruning.py +++ b/nni/algorithms/compression/pytorch/pruning/structured_pruning.py @@ -278,7 +278,8 @@ def _dependency_calc_mask(self, sparsities, wrappers, wrappers_idx, channel_dset sparsity, _w, _w_idx) num_total = current_weight.size(0) if num_total < 2 or num_prune < 1: - return base_mask + masks[name] = base_mask + continue _tmp_mask = self.get_mask( base_mask, current_weight, num_prune, _w, _w_idx, channel_masks)