From fea376867c05b391c64309cd3f78e9660c9ebfe3 Mon Sep 17 00:00:00 2001 From: linbinskn <756691769@qq.com> Date: Fri, 2 Jul 2021 16:28:08 +0800 Subject: [PATCH 01/11] add global sort for taylor pruner --- .../pruning/structured_pruning_masker.py | 39 +++++++++++++++++-- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py index eb3cc06ebd..c3be121cd0 100644 --- a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py +++ b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py @@ -33,11 +33,12 @@ class StructuredWeightMasker(WeightMasker): """ - def __init__(self, model, pruner, preserve_round=1, dependency_aware=False): + def __init__(self, model, pruner, preserve_round=1, dependency_aware=False, global_sort=False): self.model = model self.pruner = pruner self.preserve_round = preserve_round self.dependency_aware = dependency_aware + self.global_sort = global_sort def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs): """ @@ -60,7 +61,11 @@ def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs): depen_kwargs: dict The kw_args for the dependency-aware mode. """ - if not self.dependency_aware: + if self.global_sort: + # if the global_sort switch is on, calculate the mask based + # on global model information + return self._global_calc_mask(sparsity, wrapper, wrapper_idx) + elif not self.dependency_aware: # calculate the mask in the normal way, each layer calculate its # own mask separately return self._normal_calc_mask(sparsity, wrapper, wrapper_idx) @@ -127,6 +132,12 @@ def _get_current_state(self, sparsity, wrapper, wrapper_idx=None): # weight*mask_weight: apply base mask for iterative pruning return mask, weight * mask_weight, num_prune + def _global_calc_mask(self, sparsity, wrapper, wrapper_idx=None): + num_prune = self._get_global_num_prune(wrapper, wrapper_idx) + mask, weight, _ = self._get_current_state( + sparsity, wrapper, wrapper_idx) + return self.get_mask(mask, weight, num_prune, wrapper, wrapper_idx) + def _normal_calc_mask(self, sparsity, wrapper, wrapper_idx=None): """ Calculate the mask of given layer. 
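The idea behind the global-sort path that this first patch starts wiring in can be summarized with a small standalone sketch. It is illustrative only: the layer names and scores below are made up, and the snippet merely mirrors what the masker's _get_global_threshold / _get_global_num_prune helpers (added in the following hunks) end up doing — pool the importance scores of every layer, pick a single threshold from the k smallest, and let each layer derive its own prune count from that shared threshold.

    import torch

    # Hypothetical per-layer channel importance scores (e.g. Taylor-FO contributions);
    # in the real masker these come from get_channel_sum().
    contributions = {
        'conv1': torch.tensor([0.10, 0.90, 0.20, 0.80, 0.70]),
        'conv2': torch.tensor([0.60, 0.05, 0.50, 0.40, 0.30]),
    }
    sparsity = 0.4

    # Global sort: pool all channel scores, take the k smallest, and use the largest
    # of those as a single threshold shared by every layer.
    all_scores = torch.cat(list(contributions.values()))
    k = int(all_scores.numel() * sparsity)
    global_threshold = torch.topk(all_scores, k, largest=False)[0].max()

    # Each layer then derives its own prune count from the shared threshold instead
    # of applying the configured sparsity layer by layer.
    for name, scores in contributions.items():
        num_prune = int((scores < global_threshold).sum())
        print(name, 'prunes', num_prune, 'of', scores.numel(), 'filters')

Later patches in this series (PATCH 07-09) additionally repeat each channel's score by the number of parameters behind it before taking the global top-k, so that layers of different sizes are weighted sensibly; the sketch above omits that detail.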
@@ -471,12 +482,34 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker): http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf """ - def __init__(self, model, pruner, statistics_batch_num=1): + def __init__(self, model, pruner, statistics_batch_num=1, global_sort=False): super().__init__(model, pruner) self.statistics_batch_num = statistics_batch_num self.pruner.iterations = 0 self.pruner.set_wrappers_attribute("contribution", None) self.pruner.patch_optimizer(self.calc_contributions) + self.global_sort = global_sort + self.global_threshold = None + + def _get_global_threshold(self): + channel_contribution_list = [] + for wrapper_idx, wrapper in enumerate(self.pruner.get_modules_wrapper()): + channel_contribution = self.get_channel_sum(wrapper, wrapper_idx) + channel_contribution_list.append(channel_contribution) + all_channel_contributions = torch.cat(channel_contribution_list) + k = int(all_channel_contributions.shape[0] * self.pruner.config_list[0]['sparsity']) + self.global_threshold = torch.topk( + all_channel_contributions.view(-1), k, largest=False)[0].max() + print(f'set global threshold to {self.global_threshold}') + + def _get_global_num_prune(self, wrapper, wrapper_idx): + if self.global_threshold is None: + _get_global_threshold() + weight = wrapper.module.weight.data + filters = weight.size(0) + w_abs = weight.abs() + num_prune = w_abs[w_abs < self.global_threshold].size()[0] + return num_prune def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None): channel_contribution = self.get_channel_sum(wrapper, wrapper_idx) From 23998a948bae35587bc3c9fbc97cd5671a81610f Mon Sep 17 00:00:00 2001 From: linbinskn <756691769@qq.com> Date: Fri, 2 Jul 2021 17:39:20 +0800 Subject: [PATCH 02/11] add global sort mode in example --- examples/model_compress/pruning/basic_pruners_torch.py | 6 ++++++ .../compression/pytorch/pruning/dependency_aware_pruner.py | 3 ++- .../compression/pytorch/pruning/iterative_pruner.py | 4 ++-- .../pytorch/pruning/structured_pruning_masker.py | 5 ++--- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/examples/model_compress/pruning/basic_pruners_torch.py b/examples/model_compress/pruning/basic_pruners_torch.py index 1da4c6994f..d76128fae8 100644 --- a/examples/model_compress/pruning/basic_pruners_torch.py +++ b/examples/model_compress/pruning/basic_pruners_torch.py @@ -218,6 +218,10 @@ def trainer(model, optimizer, criterion, epoch): }] else: + if args.global_sort: + print('Enable the global_sort mode') + # only taylor pruner supports global sort mode currently + kw_args['global_sort'] = True if args.dependency_aware: dummy_input = get_dummy_input(args, device) print('Enable the dependency_aware mode') @@ -331,6 +335,8 @@ def trainer(model, optimizer, criterion, epoch): help='target overall target sparsity') parser.add_argument('--dependency-aware', action='store_true', default=False, help='toggle dependency aware mode') + parser.add_argument('--global-sort', action='store_true', default=False, + help='toggle global sort mode') parser.add_argument('--pruner', type=str, default='l1filter', choices=['level', 'l1filter', 'l2filter', 'slim', 'agp', 'fpgm', 'mean_activation', 'apoz', 'taylorfo'], diff --git a/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py b/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py index c0ca053a7d..6e19a0d0c2 100644 --- a/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py +++ 
b/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py @@ -27,7 +27,7 @@ class DependencyAwarePruner(Pruner): """ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='level', dependency_aware=False, - dummy_input=None, **algo_kwargs): + global_sort=False, dummy_input=None, **algo_kwargs): super().__init__(model, config_list=config_list, optimizer=optimizer) self.dependency_aware = dependency_aware @@ -56,6 +56,7 @@ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='level' model, self, **algo_kwargs) # set the dependency-aware switch for the masker self.masker.dependency_aware = dependency_aware + self.masker.global_sort = global_sort self.set_wrappers_attribute("if_calculated", False) def calc_mask(self, wrapper, wrapper_idx=None): diff --git a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py index d5cd5dfccb..6bb2bef17b 100644 --- a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py @@ -22,7 +22,7 @@ class IterativePruner(DependencyAwarePruner): """ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='slim', trainer=None, criterion=None, - num_iterations=20, epochs_per_iteration=5, dependency_aware=False, dummy_input=None, **algo_kwargs): + num_iterations=20, epochs_per_iteration=5, dependency_aware=False, global_sort=False, dummy_input=None, **algo_kwargs): """ Parameters ---------- @@ -489,7 +489,7 @@ class TaylorFOWeightFilterPruner(IterativePruner): """ def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1, - dependency_aware=False, dummy_input=None): + dependency_aware=False, global_sort=False, dummy_input=None): super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='taylorfo', trainer=trainer, criterion=criterion, statistics_batch_num=sparsifying_training_batches, num_iterations=1, epochs_per_iteration=1, dependency_aware=dependency_aware, diff --git a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py index c3be121cd0..72e6c75371 100644 --- a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py +++ b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py @@ -482,13 +482,12 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker): http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf """ - def __init__(self, model, pruner, statistics_batch_num=1, global_sort=False): + def __init__(self, model, pruner, statistics_batch_num=1): super().__init__(model, pruner) self.statistics_batch_num = statistics_batch_num self.pruner.iterations = 0 self.pruner.set_wrappers_attribute("contribution", None) self.pruner.patch_optimizer(self.calc_contributions) - self.global_sort = global_sort self.global_threshold = None def _get_global_threshold(self): @@ -504,7 +503,7 @@ def _get_global_threshold(self): def _get_global_num_prune(self, wrapper, wrapper_idx): if self.global_threshold is None: - _get_global_threshold() + self._get_global_threshold() weight = wrapper.module.weight.data filters = weight.size(0) w_abs = weight.abs() From 74e7731e1fbd27671eff5a0f3e2cefec5260e4a1 Mon Sep 17 00:00:00 2001 From: linbinskn <756691769@qq.com> Date: Fri, 2 Jul 2021 17:46:59 +0800 Subject: [PATCH 03/11] fix bug --- 
.../pytorch/pruning/structured_pruning_masker.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py index 72e6c75371..fb0c2d4142 100644 --- a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py +++ b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py @@ -506,8 +506,10 @@ def _get_global_num_prune(self, wrapper, wrapper_idx): self._get_global_threshold() weight = wrapper.module.weight.data filters = weight.size(0) - w_abs = weight.abs() - num_prune = w_abs[w_abs < self.global_threshold].size()[0] + channel_contribution = self.get_channel_sum(wrapper, wrapper_idx) + num_prune = channel_contribution[channel_contribution < self.global_threshold].size()[0] + if num_prune == filters: + num_prune -= 1 return num_prune def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None): From 763923e147a484a295631ca78115840a5efabc9c Mon Sep 17 00:00:00 2001 From: linbinskn <756691769@qq.com> Date: Fri, 2 Jul 2021 18:59:48 +0800 Subject: [PATCH 04/11] fix bug --- .../compression/pytorch/pruning/dependency_aware_pruner.py | 2 +- .../compression/pytorch/pruning/iterative_pruner.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py b/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py index 6e19a0d0c2..ad61825d7b 100644 --- a/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py @@ -27,7 +27,7 @@ class DependencyAwarePruner(Pruner): """ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='level', dependency_aware=False, - global_sort=False, dummy_input=None, **algo_kwargs): + dummy_input=None, global_sort=False, **algo_kwargs): super().__init__(model, config_list=config_list, optimizer=optimizer) self.dependency_aware = dependency_aware diff --git a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py index 6bb2bef17b..869aa6c16b 100644 --- a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py @@ -22,7 +22,7 @@ class IterativePruner(DependencyAwarePruner): """ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='slim', trainer=None, criterion=None, - num_iterations=20, epochs_per_iteration=5, dependency_aware=False, global_sort=False, dummy_input=None, **algo_kwargs): + num_iterations=20, epochs_per_iteration=5, dependency_aware=False, dummy_input=None, global_sort=False, **algo_kwargs): """ Parameters ---------- @@ -489,7 +489,7 @@ class TaylorFOWeightFilterPruner(IterativePruner): """ def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1, - dependency_aware=False, global_sort=False, dummy_input=None): + dependency_aware=False, dummy_input=None, global_sort=False): super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='taylorfo', trainer=trainer, criterion=criterion, statistics_batch_num=sparsifying_training_batches, num_iterations=1, epochs_per_iteration=1, dependency_aware=dependency_aware, From be5f78ca67eea4d32615ad28bdd4cf8d11c5d560 Mon Sep 17 00:00:00 2001 From: linbinskn <756691769@qq.com> Date: Fri, 2 Jul 2021 
21:44:43 +0800 Subject: [PATCH 05/11] pass pipeline --- examples/model_compress/pruning/basic_pruners_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/model_compress/pruning/basic_pruners_torch.py b/examples/model_compress/pruning/basic_pruners_torch.py index d76128fae8..1c7ee673a0 100644 --- a/examples/model_compress/pruning/basic_pruners_torch.py +++ b/examples/model_compress/pruning/basic_pruners_torch.py @@ -362,4 +362,4 @@ def trainer(model, optimizer, criterion, epoch): args.pruner = params['pruner'] args.model = params['model'] - main(args) + main(args) \ No newline at end of file From 8a75737310076fbbc949e205370cd66b042f4801 Mon Sep 17 00:00:00 2001 From: linbinskn <756691769@qq.com> Date: Tue, 6 Jul 2021 10:28:57 +0800 Subject: [PATCH 06/11] update doc --- docs/en_US/Compression/Pruner.rst | 2 ++ .../compression/pytorch/pruning/iterative_pruner.py | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/docs/en_US/Compression/Pruner.rst b/docs/en_US/Compression/Pruner.rst index 56259742fc..d833873a1f 100644 --- a/docs/en_US/Compression/Pruner.rst +++ b/docs/en_US/Compression/Pruner.rst @@ -334,6 +334,8 @@ TaylorFOWeightFilter Pruner is a pruner which prunes convolutional layers based We also provide a dependency-aware mode for this pruner to get better speedup from the pruning. Please reference `dependency-aware <./DependencyAware.rst>`__ for more details. +In addition, we provide a global-sort mode for this pruner, which is aligned with the implementation in the original paper. To enable it, set the parameter 'global_sort' to True when instantiating TaylorFOWeightFilterPruner. + Usage ^^^^^ diff --git a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py index 869aa6c16b..3391f2cc0f 100644 --- a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py @@ -51,6 +51,9 @@ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='slim', dummy_input: torch.Tensor The dummy input to analyze the topology constraints. Note that, the dummy_input should on the same device with the model. + global_sort: bool + If prune the model in a global-sort way. + Only support TaylorFOWeightFilterPruner currently. algo_kwargs: dict Additional parameters passed to pruning algorithm masker class """ @@ -486,6 +489,11 @@ class TaylorFOWeightFilterPruner(IterativePruner): dummy_input : torch.Tensor The dummy input to analyze the topology constraints. Note that, the dummy_input should on the same device with the model. + global_sort: bool + Only supported by TaylorFOWeightFilterPruner currently. + Whether to prune the model in a global-sort way. If it is `True`, this pruner will prune + the model according to global contribution information: channel contributions are sorted + across all layers, and whether a specific channel is pruned depends on that global ranking.
""" def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1, From a5e600de87a2146ad0a50e93390f820e9117496a Mon Sep 17 00:00:00 2001 From: v-linbin Date: Fri, 9 Jul 2021 16:26:13 +0800 Subject: [PATCH 07/11] fix global sort k calculation --- .../compression/pytorch/pruning/structured_pruning_masker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py index fb0c2d4142..71bf65e1fc 100644 --- a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py +++ b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py @@ -494,7 +494,9 @@ def _get_global_threshold(self): channel_contribution_list = [] for wrapper_idx, wrapper in enumerate(self.pruner.get_modules_wrapper()): channel_contribution = self.get_channel_sum(wrapper, wrapper_idx) - channel_contribution_list.append(channel_contribution) + wrapper_size = wrapper.module.weight.size().numel() + channel_size = wrapper.module.weight.size(0) + channel_contribution_list.append(channel_contribution.expand(wrapper_size / channel_size, channel_size)) all_channel_contributions = torch.cat(channel_contribution_list) k = int(all_channel_contributions.shape[0] * self.pruner.config_list[0]['sparsity']) self.global_threshold = torch.topk( From 959422a69a52d75020b58d7bfe962a7a80a45ff6 Mon Sep 17 00:00:00 2001 From: v-linbin Date: Sat, 10 Jul 2021 12:23:21 +0800 Subject: [PATCH 08/11] fix different size problem --- .../compression/pytorch/pruning/structured_pruning_masker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py index 71bf65e1fc..e5f2060c1b 100644 --- a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py +++ b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py @@ -496,7 +496,8 @@ def _get_global_threshold(self): channel_contribution = self.get_channel_sum(wrapper, wrapper_idx) wrapper_size = wrapper.module.weight.size().numel() channel_size = wrapper.module.weight.size(0) - channel_contribution_list.append(channel_contribution.expand(wrapper_size / channel_size, channel_size)) + contribution_expand = channel_contribution.expand(wrapper_size / channel_size, channel_size).view(-1) + channel_contribution_list.append(contribution_expand) all_channel_contributions = torch.cat(channel_contribution_list) k = int(all_channel_contributions.shape[0] * self.pruner.config_list[0]['sparsity']) self.global_threshold = torch.topk( From e0010b289b33ed4b24ef97c40a02961c69720618 Mon Sep 17 00:00:00 2001 From: v-linbin Date: Sun, 11 Jul 2021 09:23:48 +0800 Subject: [PATCH 09/11] fix size issue --- .../compression/pytorch/pruning/iterative_pruner.py | 4 ++-- .../compression/pytorch/pruning/structured_pruning_masker.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py index 3391f2cc0f..004a233d7a 100644 --- a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py @@ -57,7 +57,7 @@ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='slim', algo_kwargs: dict Additional parameters passed to 
pruning algorithm masker class """ - super().__init__(model, config_list, optimizer, pruning_algorithm, dependency_aware, dummy_input, **algo_kwargs) + super().__init__(model, config_list, optimizer, pruning_algorithm, dependency_aware, dummy_input, global_sort, **algo_kwargs) if isinstance(epochs_per_iteration, list): assert len(epochs_per_iteration) == num_iterations, 'num_iterations should equal to the length of epochs_per_iteration' @@ -501,7 +501,7 @@ def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifyin super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='taylorfo', trainer=trainer, criterion=criterion, statistics_batch_num=sparsifying_training_batches, num_iterations=1, epochs_per_iteration=1, dependency_aware=dependency_aware, - dummy_input=dummy_input) + dummy_input=dummy_input, global_sort=global_sort) def _supported_dependency_aware(self): return True diff --git a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py index e5f2060c1b..51ad8318af 100644 --- a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py +++ b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py @@ -496,7 +496,7 @@ def _get_global_threshold(self): channel_contribution = self.get_channel_sum(wrapper, wrapper_idx) wrapper_size = wrapper.module.weight.size().numel() channel_size = wrapper.module.weight.size(0) - contribution_expand = channel_contribution.expand(wrapper_size / channel_size, channel_size).view(-1) + contribution_expand = channel_contribution.expand(int(wrapper_size / channel_size), channel_size).reshape(-1) channel_contribution_list.append(contribution_expand) all_channel_contributions = torch.cat(channel_contribution_list) k = int(all_channel_contributions.shape[0] * self.pruner.config_list[0]['sparsity']) self.global_threshold = torch.topk( From 4fccc26aad582fa105ed153343a2d83ea087d267 Mon Sep 17 00:00:00 2001 From: v-linbin Date: Wed, 14 Jul 2021 22:43:17 +0800 Subject: [PATCH 10/11] delete global sort in dependency aware pruner --- .../pytorch/pruning/dependency_aware_pruner.py | 3 +-- .../compression/pytorch/pruning/iterative_pruner.py | 10 ++++------ .../pytorch/pruning/structured_pruning_masker.py | 1 - 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py b/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py index ad61825d7b..c0ca053a7d 100644 --- a/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py @@ -27,7 +27,7 @@ class DependencyAwarePruner(Pruner): """ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='level', dependency_aware=False, - dummy_input=None, global_sort=False, **algo_kwargs): + dummy_input=None, **algo_kwargs): super().__init__(model, config_list=config_list, optimizer=optimizer) self.dependency_aware = dependency_aware @@ -56,7 +56,6 @@ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='level' model, self, **algo_kwargs) # set the dependency-aware switch for the masker self.masker.dependency_aware = dependency_aware - self.masker.global_sort = global_sort self.set_wrappers_attribute("if_calculated", False) def calc_mask(self, wrapper, wrapper_idx=None): diff --git a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py
index 004a233d7a..106db71d56 100644 --- a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py @@ -22,7 +22,7 @@ class IterativePruner(DependencyAwarePruner): """ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='slim', trainer=None, criterion=None, - num_iterations=20, epochs_per_iteration=5, dependency_aware=False, dummy_input=None, global_sort=False, **algo_kwargs): + num_iterations=20, epochs_per_iteration=5, dependency_aware=False, dummy_input=None, **algo_kwargs): """ Parameters ---------- @@ -51,13 +51,10 @@ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='slim', dummy_input: torch.Tensor The dummy input to analyze the topology constraints. Note that, the dummy_input should on the same device with the model. - global_sort: bool - If prune the model in a global-sort way. - Only support TaylorFOWeightFilterPruner currently. algo_kwargs: dict Additional parameters passed to pruning algorithm masker class """ - super().__init__(model, config_list, optimizer, pruning_algorithm, dependency_aware, dummy_input, global_sort, **algo_kwargs) + super().__init__(model, config_list, optimizer, pruning_algorithm, dependency_aware, dummy_input, **algo_kwargs) if isinstance(epochs_per_iteration, list): assert len(epochs_per_iteration) == num_iterations, 'num_iterations should equal to the length of epochs_per_iteration' @@ -501,7 +498,8 @@ def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifyin super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='taylorfo', trainer=trainer, criterion=criterion, statistics_batch_num=sparsifying_training_batches, num_iterations=1, epochs_per_iteration=1, dependency_aware=dependency_aware, - dummy_input=dummy_input, global_sort=global_sort) + dummy_input=dummy_input) + self.masker.global_sort = global_sort def _supported_dependency_aware(self): return True diff --git a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py index 51ad8318af..65e0204dc8 100644 --- a/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py +++ b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py @@ -502,7 +502,6 @@ def _get_global_threshold(self): k = int(all_channel_contributions.shape[0] * self.pruner.config_list[0]['sparsity']) self.global_threshold = torch.topk( all_channel_contributions.view(-1), k, largest=False)[0].max() - print(f'set global threshold to {self.global_threshold}') def _get_global_num_prune(self, wrapper, wrapper_idx): if self.global_threshold is None: From dea48496aed2654ec8f4b3f184cee17e7d739b72 Mon Sep 17 00:00:00 2001 From: v-linbin Date: Thu, 15 Jul 2021 08:39:25 +0800 Subject: [PATCH 11/11] add ut for taylor global-sort --- test/ut/sdk/test_compressor_torch.py | 44 ++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/test/ut/sdk/test_compressor_torch.py b/test/ut/sdk/test_compressor_torch.py index e6e34d30e6..c78fa982b1 100644 --- a/test/ut/sdk/test_compressor_torch.py +++ b/test/ut/sdk/test_compressor_torch.py @@ -219,6 +219,50 @@ def test_torch_taylorFOweight_pruner(self): assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.])) assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ])) + def 
test_torch_taylorFOweight_pruner_global_sort(self): + """ + After enabling global_sort, the taylorFOweight pruner will calculate contributions and rank the top-k across all + of the conv operators. Then it will prune low-contribution filters based on the global information. + + So if the sparsity of the conv operators is 0.4, the expected masks should mask out filters of conv1 and conv2 jointly, + which can be verified through: + `all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.]))` + `all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))` + """ + + w1 = np.array([np.zeros((1, 5, 5)), np.ones((1, 5, 5)), np.ones((1, 5, 5)) * 2, + np.ones((1, 5, 5)) * 3, np.ones((1, 5, 5)) * 4]) + w2 = np.array([[[[i + 1] * 5] * 5] * 5 for i in range(10)[::-1]]) + + grad1 = np.array([np.ones((1, 5, 5)) * -1, np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1, + np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1]) + + grad2 = np.array([[[[(-1)**i] * 5] * 5] * 5 for i in range(10)]) + + config_list = [{'sparsity': 0.4, 'op_types': ['Conv2d']}] + + model = TorchModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) + pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsifying_training_batches=1, global_sort=True) + + x = torch.rand((1, 1, 28, 28), requires_grad=True) + model.conv1.module.weight.data = torch.tensor(w1).float() + model.conv2.module.weight.data = torch.tensor(w2).float() + + y = model(x) + y.backward(torch.ones_like(y)) + + model.conv1.module.weight.grad.data = torch.tensor(grad1).float() + model.conv2.module.weight.grad.data = torch.tensor(grad2).float() + optimizer.step() + + mask1 = pruner.calc_mask(model.conv1) + mask2 = pruner.calc_mask(model.conv2) + print(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy()) + print(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy()) + assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.])) + assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.])) + def test_torch_QAT_quantizer(self): model = TorchModel() config_list = [{
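For readers who want to exercise the new mode outside of this unit test, a minimal usage sketch follows. It is a sketch, not part of the patch: the toy network, random data, and hyper-parameters are placeholders, and it assumes the TaylorFOWeightFilterPruner signature as it stands at the end of this series (trainer/criterion arguments plus the new global_sort keyword).

    import torch
    import torch.nn.functional as F
    from nni.algorithms.compression.pytorch.pruning import TaylorFOWeightFilterPruner

    class Net(torch.nn.Module):
        # toy CNN used only for illustration
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(1, 8, 3)
            self.conv2 = torch.nn.Conv2d(8, 16, 3)
            self.fc = torch.nn.Linear(16, 10)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.relu(self.conv2(x))
            x = F.adaptive_avg_pool2d(x, 1).flatten(1)
            return self.fc(x)

    def trainer(model, optimizer, criterion, epoch):
        # placeholder loop over random data; plug in a real dataloader in practice
        model.train()
        for _ in range(2):
            data, target = torch.rand(4, 1, 28, 28), torch.randint(0, 10, (4,))
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

    model = Net()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    config_list = [{'sparsity': 0.4, 'op_types': ['Conv2d']}]

    # global_sort=True ranks channel contributions across all Conv2d layers instead
    # of applying the sparsity to each layer independently.
    pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=trainer,
                                        criterion=F.cross_entropy,
                                        sparsifying_training_batches=1, global_sort=True)
    pruner.compress()

The same mode can also be exercised from examples/model_compress/pruning/basic_pruners_torch.py through the --global-sort flag added in PATCH 02.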