diff --git a/src/sdk/pynni/nni/compression/torch/pruning/one_shot.py b/src/sdk/pynni/nni/compression/torch/pruning/one_shot.py index f74eba2a52..b58477a653 100644 --- a/src/sdk/pynni/nni/compression/torch/pruning/one_shot.py +++ b/src/sdk/pynni/nni/compression/torch/pruning/one_shot.py @@ -94,9 +94,11 @@ class LevelPruner(OneshotPruner): Supported keys: - sparsity : This is to specify the sparsity operations to be compressed to. - op_types : Operation types to prune. + optimizer: torch.optim.Optimizer + Optimizer used to train model """ - def __init__(self, model, config_list): - super().__init__(model, config_list, pruning_algorithm='level') + def __init__(self, model, config_list, optimizer=None): + super().__init__(model, config_list, pruning_algorithm='level', optimizer=optimizer) class SlimPruner(OneshotPruner): """ @@ -108,9 +110,11 @@ class SlimPruner(OneshotPruner): Supported keys: - sparsity : This is to specify the sparsity operations to be compressed to. - op_types : Only BatchNorm2d is supported in Slim Pruner. + optimizer: torch.optim.Optimizer + Optimizer used to train model """ - def __init__(self, model, config_list): - super().__init__(model, config_list, pruning_algorithm='slim') + def __init__(self, model, config_list, optimizer=None): + super().__init__(model, config_list, pruning_algorithm='slim', optimizer=optimizer) def validate_config(self, model, config_list): schema = CompressorSchema([{ @@ -147,9 +151,11 @@ class L1FilterPruner(_StructuredFilterPruner): Supported keys: - sparsity : This is to specify the sparsity operations to be compressed to. - op_types : Only Conv2d is supported in L1FilterPruner. + optimizer: torch.optim.Optimizer + Optimizer used to train model """ - def __init__(self, model, config_list): - super().__init__(model, config_list, pruning_algorithm='l1') + def __init__(self, model, config_list, optimizer=None): + super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer) class L2FilterPruner(_StructuredFilterPruner): """ @@ -161,9 +167,11 @@ class L2FilterPruner(_StructuredFilterPruner): Supported keys: - sparsity : This is to specify the sparsity operations to be compressed to. - op_types : Only Conv2d is supported in L2FilterPruner. + optimizer: torch.optim.Optimizer + Optimizer used to train model """ - def __init__(self, model, config_list): - super().__init__(model, config_list, pruning_algorithm='l2') + def __init__(self, model, config_list, optimizer=None): + super().__init__(model, config_list, pruning_algorithm='l2', optimizer=optimizer) class FPGMPruner(_StructuredFilterPruner): """ @@ -175,9 +183,11 @@ class FPGMPruner(_StructuredFilterPruner): Supported keys: - sparsity : This is to specify the sparsity operations to be compressed to. - op_types : Only Conv2d is supported in FPGM Pruner. + optimizer: torch.optim.Optimizer + Optimizer used to train model """ - def __init__(self, model, config_list): - super().__init__(model, config_list, pruning_algorithm='fpgm') + def __init__(self, model, config_list, optimizer=None): + super().__init__(model, config_list, pruning_algorithm='fpgm', optimizer=optimizer) class TaylorFOWeightFilterPruner(_StructuredFilterPruner): """ @@ -189,6 +199,8 @@ class TaylorFOWeightFilterPruner(_StructuredFilterPruner): Supported keys: - sparsity : How much percentage of convolutional filters are to be pruned. - op_types : Currently only Conv2d is supported in TaylorFOWeightFilterPruner. + optimizer: torch.optim.Optimizer + Optimizer used to train model """ def __init__(self, model, config_list, optimizer=None, statistics_batch_num=1): super().__init__(model, config_list, pruning_algorithm='taylorfo', optimizer=optimizer, statistics_batch_num=statistics_batch_num) @@ -203,6 +215,8 @@ class ActivationAPoZRankFilterPruner(_StructuredFilterPruner): Supported keys: - sparsity : How much percentage of convolutional filters are to be pruned. - op_types : Only Conv2d is supported in ActivationAPoZRankFilterPruner. + optimizer: torch.optim.Optimizer + Optimizer used to train model """ def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1): super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, \ @@ -218,6 +232,8 @@ class ActivationMeanRankFilterPruner(_StructuredFilterPruner): Supported keys: - sparsity : How much percentage of convolutional filters are to be pruned. - op_types : Only Conv2d is supported in ActivationMeanRankFilterPruner. + optimizer: torch.optim.Optimizer + Optimizer used to train model """ def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1): super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer, \ diff --git a/src/sdk/pynni/tests/test_compressor.py b/src/sdk/pynni/tests/test_compressor.py index 6a8727c9e4..87afb5f23c 100644 --- a/src/sdk/pynni/tests/test_compressor.py +++ b/src/sdk/pynni/tests/test_compressor.py @@ -88,8 +88,9 @@ def test_torch_quantizer_modules_detection(self): def test_torch_level_pruner(self): model = TorchModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) configure_list = [{'sparsity': 0.8, 'op_types': ['default']}] - torch_compressor.LevelPruner(model, configure_list).compress() + torch_compressor.LevelPruner(model, configure_list, optimizer).compress() @tf2 def test_tf_level_pruner(self): @@ -128,7 +129,7 @@ def test_torch_fpgm_pruner(self): model = TorchModel() config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}] - pruner = torch_compressor.FPGMPruner(model, config_list) + pruner = torch_compressor.FPGMPruner(model, config_list, torch.optim.SGD(model.parameters(), lr=0.01)) model.conv2.module.weight.data = torch.tensor(w).float() masks = pruner.calc_mask(model.conv2) @@ -314,7 +315,7 @@ def test_torch_QAT_quantizer(self): def test_torch_pruner_validation(self): # test bad configuraiton pruner_classes = [torch_compressor.__dict__[x] for x in \ - ['LevelPruner', 'SlimPruner', 'FPGMPruner', 'L1FilterPruner', 'L2FilterPruner', \ + ['LevelPruner', 'SlimPruner', 'FPGMPruner', 'L1FilterPruner', 'L2FilterPruner', 'AGPPruner',\ 'ActivationMeanRankFilterPruner', 'ActivationAPoZRankFilterPruner']] bad_configs = [ @@ -336,10 +337,11 @@ def test_torch_pruner_validation(self): ] ] model = TorchModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) for pruner_class in pruner_classes: for config_list in bad_configs: try: - pruner_class(model, config_list) + pruner_class(model, config_list, optimizer) print(config_list) assert False, 'Validation error should be raised for bad configuration' except schema.SchemaError: diff --git a/src/sdk/pynni/tests/test_pruners.py b/src/sdk/pynni/tests/test_pruners.py index 1fab9b2b2a..9bba85e3e4 100644 --- a/src/sdk/pynni/tests/test_pruners.py +++ b/src/sdk/pynni/tests/test_pruners.py @@ -192,9 +192,7 @@ def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'tayl pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer']) elif pruner_name == 'autocompress': pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'], evaluator=prune_config[pruner_name]['evaluator'], dummy_input=x) - elif pruner_name in ['level', 'slim', 'fpgm', 'l1', 'l2']: - pruner = prune_config[pruner_name]['pruner_class'](model, config_list) - elif pruner_name in ['agp', 'taylorfo', 'mean_activation', 'apoz']: + else: pruner = prune_config[pruner_name]['pruner_class'](model, config_list, optimizer) pruner.compress()