diff --git a/docs/en_US/Compression/CompressionReference.rst b/docs/en_US/Compression/CompressionReference.rst index b616b87a9e..50dcc12876 100644 --- a/docs/en_US/Compression/CompressionReference.rst +++ b/docs/en_US/Compression/CompressionReference.rst @@ -34,7 +34,7 @@ Weight Masker .. autoclass:: nni.algorithms.compression.pytorch.pruning.weight_masker.WeightMasker :members: -.. autoclass:: nni.algorithms.compression.pytorch.pruning.structured_pruning.StructuredWeightMasker +.. autoclass:: nni.algorithms.compression.pytorch.pruning.structured_pruning_masker.StructuredWeightMasker :members: @@ -43,40 +43,40 @@ Pruners .. autoclass:: nni.algorithms.compression.pytorch.pruning.sensitivity_pruner.SensitivityPruner :members: -.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.OneshotPruner +.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot_pruner.OneshotPruner :members: -.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.LevelPruner +.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot_pruner.LevelPruner :members: -.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.SlimPruner +.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot_pruner.L1FilterPruner :members: -.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.L1FilterPruner +.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot_pruner.L2FilterPruner :members: -.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.L2FilterPruner +.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot_pruner.FPGMPruner :members: -.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.FPGMPruner +.. autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.IterativePruner :members: -.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.TaylorFOWeightFilterPruner +.. autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.SlimPruner :members: -.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.ActivationAPoZRankFilterPruner +.. autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.TaylorFOWeightFilterPruner :members: -.. autoclass:: nni.algorithms.compression.pytorch.pruning.one_shot.ActivationMeanRankFilterPruner +.. autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.ActivationAPoZRankFilterPruner :members: -.. autoclass:: nni.algorithms.compression.pytorch.pruning.lottery_ticket.LotteryTicketPruner +.. autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.ActivationMeanRankFilterPruner :members: -.. autoclass:: nni.algorithms.compression.pytorch.pruning.agp.AGPPruner +.. autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.AGPPruner :members: -.. autoclass:: nni.algorithms.compression.pytorch.pruning.admm_pruner.ADMMPruner +.. autoclass:: nni.algorithms.compression.pytorch.pruning.iterative_pruner.ADMMPruner :members: .. autoclass:: nni.algorithms.compression.pytorch.pruning.auto_compress_pruner.AutoCompressPruner @@ -88,6 +88,9 @@ Pruners .. autoclass:: nni.algorithms.compression.pytorch.pruning.simulated_annealing_pruner.SimulatedAnnealingPruner :members: +.. autoclass:: nni.algorithms.compression.pytorch.pruning.lottery_ticket.LotteryTicketPruner + :members: + Quantizers ^^^^^^^^^^ diff --git a/docs/en_US/Compression/CustomizeCompressor.rst b/docs/en_US/Compression/CustomizeCompressor.rst index f2f4f260c1..103bff818c 100644 --- a/docs/en_US/Compression/CustomizeCompressor.rst +++ b/docs/en_US/Compression/CustomizeCompressor.rst @@ -28,7 +28,7 @@ An implementation of ``weight masker`` may look like this: # mask = ... return {'weight_mask': mask} -You can reference nni provided :githublink:`weight masker ` implementations to implement your own weight masker. +You can reference nni provided :githublink:`weight masker ` implementations to implement your own weight masker. A basic ``pruner`` looks likes this: @@ -52,7 +52,7 @@ A basic ``pruner`` looks likes this: wrapper.if_calculated = True return masks -Reference nni provided :githublink:`pruner ` implementations to implement your own pruner class. +Reference nni provided :githublink:`pruner ` implementations to implement your own pruner class. ---- diff --git a/docs/en_US/Compression/Overview.rst b/docs/en_US/Compression/Overview.rst index 262d9631f1..788aa0ac84 100644 --- a/docs/en_US/Compression/Overview.rst +++ b/docs/en_US/Compression/Overview.rst @@ -14,10 +14,19 @@ NNI provides a model compression toolkit to help user compress and speed up thei * Provide friendly and easy-to-use compression utilities for users to dive into the compression process and results. * Concise interface for users to customize their own compression algorithms. + +Compression Pipeline +-------------------- + +.. image:: ../../img/compression_flow.jpg + :target: ../../img/compression_flow.jpg + :alt: + +The overall compression pipeline in NNI. For compressing a pretrained model, pruning and quantization can be used alone or in combination. + .. note:: Since NNI compression algorithms are not meant to compress model while NNI speedup tool can truly compress model and reduce latency. To obtain a truly compact model, users should conduct `model speedup <./ModelSpeedup.rst>`__. The interface and APIs are unified for both PyTorch and TensorFlow, currently only PyTorch version has been supported, TensorFlow version will be supported in future. - Supported Algorithms -------------------- @@ -26,7 +35,7 @@ The algorithms include pruning algorithms and quantization algorithms. Pruning Algorithms ^^^^^^^^^^^^^^^^^^ -Pruning algorithms compress the original network by removing redundant weights or channels of layers, which can reduce model complexity and address the over-fitting issue. +Pruning algorithms compress the original network by removing redundant weights or channels of layers, which can reduce model complexity and mitigate the over-fitting issue. .. list-table:: :header-rows: 1 @@ -96,6 +105,7 @@ Model Speedup The final goal of model compression is to reduce inference latency and model size. However, existing model compression algorithms mainly use simulation to check the performance (e.g., accuracy) of compressed model, for example, using masks for pruning algorithms, and storing quantized values still in float32 for quantization algorithms. Given the output masks and quantization bits produced by those algorithms, NNI can really speed up the model. The detailed tutorial of Masked Model Speedup can be found `here <./ModelSpeedup.rst>`__, The detailed tutorial of Mixed Precision Quantization Model Speedup can be found `here <./QuantizationSpeedup.rst>`__. + Compression Utilities --------------------- @@ -110,7 +120,6 @@ NNI model compression leaves simple interface for users to customize a new compr Reference and Feedback ---------------------- - * To `report a bug `__ for this feature in GitHub; * To `file a feature or improvement request `__ for this feature in GitHub; * To know more about `Feature Engineering with NNI <../FeatureEngineering/Overview.rst>`__\ ; diff --git a/docs/en_US/Compression/Pruner.rst b/docs/en_US/Compression/Pruner.rst index 304a56e43e..eb9c32c875 100644 --- a/docs/en_US/Compression/Pruner.rst +++ b/docs/en_US/Compression/Pruner.rst @@ -1,15 +1,11 @@ Supported Pruning Algorithms on NNI =================================== -We provide several pruning algorithms that support fine-grained weight pruning and structural filter pruning. **Fine-grained Pruning** generally results in unstructured models, which need specialized hardware or software to speed up the sparse network. **Filter Pruning** achieves acceleration by removing the entire filter. Some pruning algorithms use one-shot method that prune weights at once based on an importance metric. Other pruning algorithms control the **pruning schedule** that prune weights during optimization, including some automatic pruning algorithms. +We provide several pruning algorithms that support fine-grained weight pruning and structural filter pruning. **Fine-grained Pruning** generally results in unstructured models, which need specialized hardware or software to speed up the sparse network. **Filter Pruning** achieves acceleration by removing the entire filter. Some pruning algorithms use one-shot method that prune weights at once based on an importance metric (It is necessary to finetune the model to compensate for the loss of accuracy). Other pruning algorithms **iteratively** prune weights during optimization, which control the pruning schedule, including some automatic pruning algorithms. -**Fine-grained Pruning** - -* `Level Pruner <#level-pruner>`__ - -**Filter Pruning** - +**One-shot Pruning** +* `Level Pruner <#level-pruner>`__ ((fine-grained pruning)) * `Slim Pruner <#slim-pruner>`__ * `FPGM Pruner <#fpgm-pruner>`__ * `L1Filter Pruner <#l1filter-pruner>`__ @@ -18,7 +14,7 @@ We provide several pruning algorithms that support fine-grained weight pruning a * `Activation Mean Rank Filter Pruner <#activationmeanrankfilter-pruner>`__ * `Taylor FO On Weight Pruner <#taylorfoweightfilter-pruner>`__ -**Pruning Schedule** +**Iteratively Pruning** * `AGP Pruner <#agp-pruner>`__ * `NetAdapt Pruner <#netadapt-pruner>`__ @@ -26,10 +22,9 @@ We provide several pruning algorithms that support fine-grained weight pruning a * `AutoCompress Pruner <#autocompress-pruner>`__ * `AMC Pruner <#amc-pruner>`__ * `Sensitivity Pruner <#sensitivity-pruner>`__ +* `ADMM Pruner <#admm-pruner>`__ **Others** - -* `ADMM Pruner <#admm-pruner>`__ * `Lottery Ticket Hypothesis <#lottery-ticket-hypothesis>`__ Level Pruner @@ -382,11 +377,7 @@ PyTorch code from nni.algorithms.compression.pytorch.pruning import AGPPruner config_list = [{ - 'initial_sparsity': 0, - 'final_sparsity': 0.8, - 'start_epoch': 0, - 'end_epoch': 10, - 'frequency': 1, + 'sparsity': 0.8, 'op_types': ['default'] }] diff --git a/docs/img/compression_flow.jpg b/docs/img/compression_flow.jpg new file mode 100644 index 0000000000..18c6a0d22e Binary files /dev/null and b/docs/img/compression_flow.jpg differ diff --git a/examples/model_compress/.gitignore b/examples/model_compress/.gitignore new file mode 100644 index 0000000000..c2e41e6b0e --- /dev/null +++ b/examples/model_compress/.gitignore @@ -0,0 +1,6 @@ +.pth +.tar.gz +data/ +MNIST/ +cifar-10-batches-py/ +experiment_data/ \ No newline at end of file diff --git a/examples/model_compress/end2end_compression.py b/examples/model_compress/end2end_compression.py new file mode 100644 index 0000000000..062d6351d6 --- /dev/null +++ b/examples/model_compress/end2end_compression.py @@ -0,0 +1,300 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +NNI example for combined pruning and quantization to compress a model. +In this example, we show the compression process to first prune a model, then quantize the pruned model. + +""" +import argparse +import os +import time +import torch +import torch.nn.functional as F +import torch.optim as optim +from torch.optim.lr_scheduler import StepLR +from torchvision import datasets, transforms + +from nni.compression.pytorch.utils.counter import count_flops_params +from nni.compression.pytorch import ModelSpeedup + +from nni.algorithms.compression.pytorch.pruning import L1FilterPruner +from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer + +from models.mnist.naive import NaiveModel +from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT + + +def get_model_time_cost(model, dummy_input): + model.eval() + n_times = 100 + time_list = [] + for _ in range(n_times): + torch.cuda.synchronize() + tic = time.time() + _ = model(dummy_input) + torch.cuda.synchronize() + time_list.append(time.time()-tic) + time_list = time_list[10:] + return sum(time_list) / len(time_list) + + +def train(args, model, device, train_loader, criterion, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + loss.backward() + + optimizer.step() + if batch_idx % args.log_interval == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), loss.item())) + if args.dry_run: + break + + +def test(args, model, device, criterion, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += criterion(output, target).item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + test_loss /= len(test_loader.dataset) + acc = 100 * correct / len(test_loader.dataset) + + print('Test Loss: {:.6f} Accuracy: {}%\n'.format( + test_loss, acc)) + return acc + +def test_trt(engine, test_loader): + test_loss = 0 + correct = 0 + time_elasped = 0 + for data, target in test_loader: + output, time = engine.inference(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + time_elasped += time + test_loss /= len(test_loader.dataset) + + print('Loss: {} Accuracy: {}%'.format( + test_loss, 100 * correct / len(test_loader.dataset))) + print("Inference elapsed_time (whole dataset): {}s".format(time_elasped)) + +def main(args): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + os.makedirs(args.experiment_data_dir, exist_ok=True) + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + + train_loader = torch.utils.data.DataLoader( + datasets.MNIST('data', train=True, download=True, transform=transform), + batch_size=64,) + test_loader = torch.utils.data.DataLoader( + datasets.MNIST('data', train=False, transform=transform), + batch_size=1000) + + # Step1. Model Pretraining + model = NaiveModel().to(device) + criterion = torch.nn.NLLLoss() + optimizer = optim.Adadelta(model.parameters(), lr=args.pretrain_lr) + scheduler = StepLR(optimizer, step_size=1, gamma=0.7) + flops, params, _ = count_flops_params(model, (1, 1, 28, 28), verbose=False) + + if args.pretrained_model_dir is None: + args.pretrained_model_dir = os.path.join(args.experiment_data_dir, f'pretrained.pth') + + best_acc = 0 + for epoch in range(args.pretrain_epochs): + train(args, model, device, train_loader, criterion, optimizer, epoch) + scheduler.step() + acc = test(args, model, device, criterion, test_loader) + if acc > best_acc: + best_acc = acc + state_dict = model.state_dict() + + model.load_state_dict(state_dict) + torch.save(state_dict, args.pretrained_model_dir) + print(f'Model saved to {args.pretrained_model_dir}') + else: + state_dict = torch.load(args.pretrained_model_dir) + model.load_state_dict(state_dict) + best_acc = test(args, model, device, criterion, test_loader) + + dummy_input = torch.randn([1000, 1, 28, 28]).to(device) + time_cost = get_model_time_cost(model, dummy_input) + + # 125.49 M, 0.85M, 93.29, 1.1012 + print(f'Pretrained model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_acc: .2f}, Time Cost: {time_cost}') + + # Step2. Model Pruning + config_list = [{ + 'sparsity': args.sparsity, + 'op_types': ['Conv2d'] + }] + + kw_args = {} + if args.dependency_aware: + dummy_input = torch.randn([1000, 1, 28, 28]).to(device) + print('Enable the dependency_aware mode') + # note that, not all pruners support the dependency_aware mode + kw_args['dependency_aware'] = True + kw_args['dummy_input'] = dummy_input + + pruner = L1FilterPruner(model, config_list, **kw_args) + model = pruner.compress() + pruner.get_pruned_weights() + + mask_path = os.path.join(args.experiment_data_dir, 'mask.pth') + model_path = os.path.join(args.experiment_data_dir, 'pruned.pth') + pruner.export_model(model_path=model_path, mask_path=mask_path) + pruner._unwrap_model() # unwrap all modules to normal state + + # Step3. Model Speedup + m_speedup = ModelSpeedup(model, dummy_input, mask_path, device) + m_speedup.speedup_model() + print('model after speedup', model) + + flops, params, _ = count_flops_params(model, dummy_input, verbose=False) + acc = test(args, model, device, criterion, test_loader) + time_cost = get_model_time_cost(model, dummy_input) + print(f'Pruned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {acc: .2f}, Time Cost: {time_cost}') + + # Step4. Model Finetuning + optimizer = optim.Adadelta(model.parameters(), lr=args.pretrain_lr) + scheduler = StepLR(optimizer, step_size=1, gamma=0.7) + + best_acc = 0 + for epoch in range(args.finetune_epochs): + train(args, model, device, train_loader, criterion, optimizer, epoch) + scheduler.step() + acc = test(args, model, device, criterion, test_loader) + if acc > best_acc: + best_acc = acc + state_dict = model.state_dict() + + model.load_state_dict(state_dict) + save_path = os.path.join(args.experiment_data_dir, f'finetuned.pth') + torch.save(state_dict, save_path) + + flops, params, _ = count_flops_params(model, dummy_input, verbose=True) + time_cost = get_model_time_cost(model, dummy_input) + + # FLOPs 28.48 M, #Params: 0.18M, Accuracy: 89.03, Time Cost: 1.03 + print(f'Finetuned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_acc: .2f}, Time Cost: {time_cost}') + print(f'Model saved to {save_path}') + + # Step5. Model Quantization via QAT + config_list = [{ + 'quant_types': ['weight', 'output'], + 'quant_bits': {'weight': 8, 'output': 8}, + 'op_names': ['conv1'] + }, { + 'quant_types': ['output'], + 'quant_bits': {'output':8}, + 'op_names': ['relu1'] + }, { + 'quant_types': ['weight', 'output'], + 'quant_bits': {'weight': 8, 'output': 8}, + 'op_names': ['conv2'] + }, { + 'quant_types': ['output'], + 'quant_bits': {'output': 8}, + 'op_names': ['relu2'] + }] + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) + quantizer = QAT_Quantizer(model, config_list, optimizer) + quantizer.compress() + + # Step6. Quantization Aware Training + best_acc = 0 + for epoch in range(1): + train(args, model, device, train_loader, criterion, optimizer, epoch) + scheduler.step() + acc = test(args, model, device, criterion, test_loader) + if acc > best_acc: + best_acc = acc + state_dict = model.state_dict() + + calibration_path = os.path.join(args.experiment_data_dir, 'calibration.pth') + calibration_config = quantizer.export_model(model_path, calibration_path) + print("calibration_config: ", calibration_config) + + # Step7. Model Speedup + batch_size = 32 + input_shape = (batch_size, 1, 28, 28) + engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=32) + engine.compress() + + test_trt(engine, test_loader) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='PyTorch Example for model comporession') + + # dataset and model + # parser.add_argument('--dataset', type=str, default='mnist', + # help='dataset to use, mnist, cifar10 or imagenet') + # parser.add_argument('--data-dir', type=str, default='./data/', + # help='dataset directory') + parser.add_argument('--pretrained-model-dir', type=str, default=None, + help='path to pretrained model') + parser.add_argument('--pretrain-epochs', type=int, default=10, + help='number of epochs to pretrain the model') + parser.add_argument('--pretrain-lr', type=float, default=1.0, + help='learning rate to pretrain the model') + + parser.add_argument('--experiment-data-dir', type=str, default='./experiment_data', + help='For saving output checkpoints') + parser.add_argument('--log-interval', type=int, default=100, metavar='N', + help='how many batches to wait before logging training status') + parser.add_argument('--dry-run', action='store_true', default=False, + help='quickly check a single pass') + # parser.add_argument('--multi-gpu', action='store_true', default=False, + # help='run on mulitple gpus') + # parser.add_argument('--test-only', action='store_true', default=False, + # help='run test only') + + # pruner + # parser.add_argument('--pruner', type=str, default='l1filter', + # choices=['level', 'l1filter', 'l2filter', 'slim', 'agp', + # 'fpgm', 'mean_activation', 'apoz', 'admm'], + # help='pruner to use') + parser.add_argument('--sparsity', type=float, default=0.5, + help='target overall target sparsity') + parser.add_argument('--dependency-aware', action='store_true', default=False, + help='toggle dependency aware mode') + + # finetuning + parser.add_argument('--finetune-epochs', type=int, default=5, + help='epochs to fine tune') + # parser.add_argument('--kd', action='store_true', default=False, + # help='quickly check a single pass') + # parser.add_argument('--kd_T', type=float, default=4, + # help='temperature for KD distillation') + # parser.add_argument('--finetune-lr', type=float, default=0.5, + # help='learning rate to finetune the model') + + # speedup + # parser.add_argument('--speed-up', action='store_true', default=False, + # help='whether to speed-up the pruned model') + + # parser.add_argument('--nni', action='store_true', default=False, + # help="whether to tune the pruners using NNi tuners") + + args = parser.parse_args() + main(args) diff --git a/examples/model_compress/pruning/models/cifar10/resnet.py b/examples/model_compress/models/cifar10/resnet.py similarity index 100% rename from examples/model_compress/pruning/models/cifar10/resnet.py rename to examples/model_compress/models/cifar10/resnet.py diff --git a/examples/model_compress/pruning/models/cifar10/vgg.py b/examples/model_compress/models/cifar10/vgg.py similarity index 100% rename from examples/model_compress/pruning/models/cifar10/vgg.py rename to examples/model_compress/models/cifar10/vgg.py diff --git a/examples/model_compress/pruning/models/mnist/lenet.py b/examples/model_compress/models/mnist/lenet.py similarity index 100% rename from examples/model_compress/pruning/models/mnist/lenet.py rename to examples/model_compress/models/mnist/lenet.py diff --git a/examples/model_compress/models/mnist/naive.py b/examples/model_compress/models/mnist/naive.py new file mode 100644 index 0000000000..4609862527 --- /dev/null +++ b/examples/model_compress/models/mnist/naive.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import reduce + +class NaiveModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 20, 5, 1) + self.conv2 = torch.nn.Conv2d(20, 50, 5, 1) + self.fc1 = torch.nn.Linear(4 * 4 * 50, 500) + self.fc2 = torch.nn.Linear(500, 10) + self.relu1 = torch.nn.ReLU6() + self.relu2 = torch.nn.ReLU6() + self.relu3 = torch.nn.ReLU6() + self.max_pool1 = torch.nn.MaxPool2d(2, 2) + self.max_pool2 = torch.nn.MaxPool2d(2, 2) + + def forward(self, x): + x = self.relu1(self.conv1(x)) + x = self.max_pool1(x) + x = self.relu2(self.conv2(x)) + x = self.max_pool2(x) + x = x.view(-1, x.size()[1:].numel()) + x = self.relu3(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) \ No newline at end of file diff --git a/examples/model_compress/pruning/models/mobilenet.py b/examples/model_compress/models/mobilenet.py similarity index 100% rename from examples/model_compress/pruning/models/mobilenet.py rename to examples/model_compress/models/mobilenet.py diff --git a/examples/model_compress/pruning/models/mobilenet_v2.py b/examples/model_compress/models/mobilenet_v2.py similarity index 100% rename from examples/model_compress/pruning/models/mobilenet_v2.py rename to examples/model_compress/models/mobilenet_v2.py diff --git a/examples/model_compress/pruning/amc/amc_search.py b/examples/model_compress/pruning/amc/amc_search.py index 6e10f554b9..5c861a8887 100644 --- a/examples/model_compress/pruning/amc/amc_search.py +++ b/examples/model_compress/pruning/amc/amc_search.py @@ -12,7 +12,7 @@ from data import get_split_dataset from utils import AverageMeter, accuracy -sys.path.append('../models') +sys.path.append('../../models') def parse_args(): parser = argparse.ArgumentParser(description='AMC search script') diff --git a/examples/model_compress/pruning/amc/amc_train.py b/examples/model_compress/pruning/amc/amc_train.py index 732d3bbae9..eb02c7020a 100644 --- a/examples/model_compress/pruning/amc/amc_train.py +++ b/examples/model_compress/pruning/amc/amc_train.py @@ -22,7 +22,7 @@ from data import get_dataset from utils import AverageMeter, accuracy, progress_bar -sys.path.append('../models') +sys.path.append('../../models') from mobilenet import MobileNet from mobilenet_v2 import MobileNetV2 diff --git a/examples/model_compress/pruning/auto_pruners_torch.py b/examples/model_compress/pruning/auto_pruners_torch.py index d9e0f53824..f32faccfa8 100644 --- a/examples/model_compress/pruning/auto_pruners_torch.py +++ b/examples/model_compress/pruning/auto_pruners_torch.py @@ -13,14 +13,16 @@ from torch.optim.lr_scheduler import StepLR, MultiStepLR from torchvision import datasets, transforms -from models.mnist.lenet import LeNet -from models.cifar10.vgg import VGG -from models.cifar10.resnet import ResNet18, ResNet50 from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, L2FilterPruner, FPGMPruner from nni.algorithms.compression.pytorch.pruning import SimulatedAnnealingPruner, ADMMPruner, NetAdaptPruner, AutoCompressPruner from nni.compression.pytorch import ModelSpeedup from nni.compression.pytorch.utils.counter import count_flops_params +import sys +sys.path.append('../models') +from mnist.lenet import LeNet +from cifar10.vgg import VGG +from cifar10.resnet import ResNet18, ResNet50 def get_data(dataset, data_dir, batch_size, test_batch_size): ''' @@ -67,7 +69,7 @@ def get_data(dataset, data_dir, batch_size, test_batch_size): return train_loader, val_loader, criterion -def train(args, model, device, train_loader, criterion, optimizer, epoch, callback=None): +def train(args, model, device, train_loader, criterion, optimizer, epoch): model.train() for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) @@ -75,9 +77,6 @@ def train(args, model, device, train_loader, criterion, optimizer, epoch, callba output = model(data) loss = criterion(output, target) loss.backward() - # callback should be inserted between loss.backward() and optimizer.step() - if callback: - callback() optimizer.step() if batch_idx % args.log_interval == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( @@ -198,8 +197,8 @@ def short_term_fine_tuner(model, epochs=1): for epoch in range(epochs): train(args, model, device, train_loader, criterion, optimizer, epoch) - def trainer(model, optimizer, criterion, epoch, callback): - return train(args, model, device, train_loader, criterion, optimizer, epoch=epoch, callback=callback) + def trainer(model, optimizer, criterion, epoch): + return train(args, model, device, train_loader, criterion, optimizer, epoch=epoch) def evaluator(model): return test(model, device, criterion, val_loader) @@ -264,7 +263,7 @@ def evaluator(model): }] else: raise ValueError('Example only implemented for LeNet.') - pruner = ADMMPruner(model, config_list, trainer=trainer, num_iterations=2, training_epochs=2) + pruner = ADMMPruner(model, config_list, trainer=trainer, num_iterations=2, epochs_per_iteration=2) elif args.pruner == 'SimulatedAnnealingPruner': pruner = SimulatedAnnealingPruner( model, config_list, evaluator=evaluator, base_algo=args.base_algo, @@ -273,7 +272,7 @@ def evaluator(model): pruner = AutoCompressPruner( model, config_list, trainer=trainer, evaluator=evaluator, dummy_input=dummy_input, num_iterations=3, optimize_mode='maximize', base_algo=args.base_algo, - cool_down_rate=args.cool_down_rate, admm_num_iterations=30, admm_training_epochs=5, + cool_down_rate=args.cool_down_rate, admm_num_iterations=30, admm_epochs_per_iteration=5, experiment_data_dir=args.experiment_data_dir) else: raise ValueError( diff --git a/examples/model_compress/pruning/basic_pruners_torch.py b/examples/model_compress/pruning/basic_pruners_torch.py index c3225353f4..51c2068aa3 100644 --- a/examples/model_compress/pruning/basic_pruners_torch.py +++ b/examples/model_compress/pruning/basic_pruners_torch.py @@ -12,25 +12,24 @@ import argparse import os -import time +import sys import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim from torch.optim.lr_scheduler import StepLR, MultiStepLR from torchvision import datasets, transforms -from models.mnist.lenet import LeNet -from models.cifar10.vgg import VGG +sys.path.append('../models') +from mnist.lenet import LeNet +from cifar10.vgg import VGG from nni.compression.pytorch.utils.counter import count_flops_params import nni -from nni.compression.pytorch import apply_compression_results, ModelSpeedup +from nni.compression.pytorch import ModelSpeedup from nni.algorithms.compression.pytorch.pruning import ( LevelPruner, SlimPruner, FPGMPruner, + TaylorFOWeightFilterPruner, L1FilterPruner, L2FilterPruner, AGPPruner, @@ -38,7 +37,6 @@ ActivationAPoZRankFilterPruner ) - _logger = logging.getLogger('mnist_example') _logger.setLevel(logging.INFO) @@ -50,7 +48,8 @@ 'agp': AGPPruner, 'fpgm': FPGMPruner, 'mean_activation': ActivationMeanRankFilterPruner, - 'apoz': ActivationAPoZRankFilterPruner + 'apoz': ActivationAPoZRankFilterPruner, + 'taylorfo': TaylorFOWeightFilterPruner } def get_dummy_input(args, device): @@ -60,53 +59,6 @@ def get_dummy_input(args, device): dummy_input = torch.randn([args.test_batch_size, 3, 32, 32]).to(device) return dummy_input -def get_pruner(model, pruner_name, device, optimizer=None, dependency_aware=False): - - pruner_cls = str2pruner[pruner_name] - - if pruner_name == 'level': - config_list = [{ - 'sparsity': args.sparsity, - 'op_types': ['default'] - }] - elif pruner_name in ['l1filter', 'mean_activation', 'apoz']: - # Reproduced result in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS', - # Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A' - config_list = [{ - 'sparsity': args.sparsity, - 'op_types': ['Conv2d'], - 'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37'] - }] - elif pruner_name == 'slim': - config_list = [{ - 'sparsity': args.sparsity, - 'op_types': ['BatchNorm2d'], - }] - elif pruner_name == 'agp': - config_list = [{ - 'initial_sparsity': 0., - 'final_sparsity': 0.8, - 'start_epoch': 0, - 'end_epoch': 10, - 'frequency': 1, - 'op_types': ['Conv2d'] - }] - else: - config_list = [{ - 'sparsity': args.sparsity, - 'op_types': ['Conv2d'] - }] - - kw_args = {} - if dependency_aware: - dummy_input = get_dummy_input(args, device) - print('Enable the dependency_aware mode') - # note that, not all pruners support the dependency_aware mode - kw_args['dependency_aware'] = True - kw_args['dummy_input'] = dummy_input - - pruner = pruner_cls(model, config_list, optimizer, **kw_args) - return pruner def get_data(dataset, data_dir, batch_size, test_batch_size): kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else { @@ -174,7 +126,7 @@ def get_model_optimizer_scheduler(args, device, train_loader, test_loader, crite print('start pre-training...') best_acc = 0 for epoch in range(args.pretrain_epochs): - train(args, model, device, train_loader, criterion, optimizer, epoch, sparse_bn=True if args.pruner == 'slim' else False) + train(args, model, device, train_loader, criterion, optimizer, epoch) scheduler.step() acc = test(args, model, device, criterion, test_loader) if acc > best_acc: @@ -198,12 +150,7 @@ def get_model_optimizer_scheduler(args, device, train_loader, test_loader, crite print('Pretrained model acc:', best_acc) return model, optimizer, scheduler -def updateBN(model): - for m in model.modules(): - if isinstance(m, nn.BatchNorm2d): - m.weight.grad.data.add_(0.0001 * torch.sign(m.weight.data)) - -def train(args, model, device, train_loader, criterion, optimizer, epoch, sparse_bn=False): +def train(args, model, device, train_loader, criterion, optimizer, epoch): model.train() for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) @@ -211,11 +158,6 @@ def train(args, model, device, train_loader, criterion, optimizer, epoch, sparse output = model(data) loss = criterion(output, target) loss.backward() - - if sparse_bn: - # L1 regularization on BN layer - updateBN(model) - optimizer.step() if batch_idx % args.log_interval == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( @@ -256,64 +198,99 @@ def main(args): flops, params, results = count_flops_params(model, dummy_input) print(f"FLOPs: {flops}, params: {params}") - print('start pruning...') + print(f'start {args.pruner} pruning...') + + def trainer(model, optimizer, criterion, epoch): + return train(args, model, device, train_loader, criterion, optimizer, epoch=epoch) + + pruner_cls = str2pruner[args.pruner] + + kw_args = {} + config_list = [{ + 'sparsity': args.sparsity, + 'op_types': ['Conv2d'] + }] + + if args.pruner == 'level': + config_list = [{ + 'sparsity': args.sparsity, + 'op_types': ['default'] + }] + + else: + if args.dependency_aware: + dummy_input = get_dummy_input(args, device) + print('Enable the dependency_aware mode') + # note that, not all pruners support the dependency_aware mode + kw_args['dependency_aware'] = True + kw_args['dummy_input'] = dummy_input + if args.pruner not in ('l1filter', 'l2filter', 'fpgm'): + # set only work for training aware pruners + kw_args['trainer'] = trainer + kw_args['optimizer'] = optimizer + kw_args['criterion'] = criterion + + if args.pruner in ('slim', 'mean_activation', 'apoz', 'taylorfo'): + kw_args['sparsity_training_epochs'] = 5 + + if args.pruner == 'agp': + kw_args['pruning_algorithm'] = 'l1' + kw_args['num_iterations'] = 5 + kw_args['epochs_per_iteration'] = 1 + + # Reproduced result in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS', + # Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A' + if args.pruner == 'slim': + config_list = [{ + 'sparsity': args.sparsity, + 'op_types': ['BatchNorm2d'], + }] + else: + config_list = [{ + 'sparsity': args.sparsity, + 'op_types': ['Conv2d'], + 'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37'] + }] + + pruner = pruner_cls(model, config_list, **kw_args) + + # Pruner.compress() returns the masked model + model = pruner.compress() + pruner.get_pruned_weights() + + # export the pruned model masks for model speedup model_path = os.path.join(args.experiment_data_dir, 'pruned_{}_{}_{}.pth'.format( args.model, args.dataset, args.pruner)) mask_path = os.path.join(args.experiment_data_dir, 'mask_{}_{}_{}.pth'.format( args.model, args.dataset, args.pruner)) - - pruner = get_pruner(model, args.pruner, device, optimizer, args.dependency_aware) - model = pruner.compress() - - if args.multi_gpu and torch.cuda.device_count() > 1: - model = nn.DataParallel(model) + pruner.export_model(model_path=model_path, mask_path=mask_path) if args.test_only: test(args, model, device, criterion, test_loader) + # Unwrap all modules to normal state + pruner._unwrap_model() + m_speedup = ModelSpeedup(model, dummy_input, mask_path, device) + m_speedup.speedup_model() + + print('start finetuning...') best_top1 = 0 + save_path = os.path.join(args.experiment_data_dir, f'finetuned.pth') for epoch in range(args.fine_tune_epochs): - pruner.update_epoch(epoch) print('# Epoch {} #'.format(epoch)) train(args, model, device, train_loader, criterion, optimizer, epoch) scheduler.step() top1 = test(args, model, device, criterion, test_loader) if top1 > best_top1: best_top1 = top1 - # Export the best model, 'model_path' stores state_dict of the pruned model, - # mask_path stores mask_dict of the pruned model - pruner.export_model(model_path=model_path, mask_path=mask_path) + torch.save(model.state_dict(), save_path) + + flops, params, results = count_flops_params(model, dummy_input) + print(f'Finetuned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_top1: .2f}') if args.nni: nni.report_final_result(best_top1) - if args.speed_up: - # reload the best checkpoint for speed-up - args.pretrained_model_dir = model_path - model, _, _ = get_model_optimizer_scheduler(args, device, train_loader, test_loader, criterion) - model.eval() - - apply_compression_results(model, mask_path, device) - - # test model speed - start = time.time() - for _ in range(32): - use_mask_out = model(dummy_input) - print('elapsed time when use mask: ', time.time() - start) - - m_speedup = ModelSpeedup(model, dummy_input, mask_path, device) - m_speedup.speedup_model() - - flops, params, results = count_flops_params(model, dummy_input) - print(f"FLOPs: {flops}, params: {params}") - - start = time.time() - for _ in range(32): - use_speedup_out = model(dummy_input) - print('elapsed time when use speedup: ', time.time() - start) - - top1 = test(args, model, device, criterion, test_loader) - if __name__ == '__main__': parser = argparse.ArgumentParser(description='PyTorch Example for model comporession') @@ -352,17 +329,13 @@ def main(args): help='toggle dependency aware mode') parser.add_argument('--pruner', type=str, default='l1filter', choices=['level', 'l1filter', 'l2filter', 'slim', 'agp', - 'fpgm', 'mean_activation', 'apoz'], + 'fpgm', 'mean_activation', 'apoz', 'taylorfo'], help='pruner to use') # fine-tuning parser.add_argument('--fine-tune-epochs', type=int, default=160, help='epochs to fine tune') - # speed-up - parser.add_argument('--speed-up', action='store_true', default=False, - help='whether to speed-up the pruned model') - parser.add_argument('--nni', action='store_true', default=False, help="whether to tune the pruners using NNi tuners") diff --git a/examples/model_compress/pruning/finetune_kd_torch.py b/examples/model_compress/pruning/finetune_kd_torch.py index 10fccd3484..68c96b4ba3 100644 --- a/examples/model_compress/pruning/finetune_kd_torch.py +++ b/examples/model_compress/pruning/finetune_kd_torch.py @@ -20,8 +20,11 @@ from torch.optim.lr_scheduler import MultiStepLR, StepLR from torchvision import datasets, transforms from basic_pruners_torch import get_data -from models.cifar10.vgg import VGG -from models.mnist.lenet import LeNet + +import sys +sys.path.append('../models') +from cifar10.vgg import VGG +from mnist.lenet import LeNet class DistillKL(nn.Module): """Distilling the Knowledge in a Neural Network""" diff --git a/examples/model_compress/pruning/lottery_torch_mnist_fc.py b/examples/model_compress/pruning/lottery_torch_mnist_fc.py index 7a46c79834..215bc5f5f7 100644 --- a/examples/model_compress/pruning/lottery_torch_mnist_fc.py +++ b/examples/model_compress/pruning/lottery_torch_mnist_fc.py @@ -20,7 +20,7 @@ class fc1(nn.Module): def __init__(self, num_classes=10): super(fc1, self).__init__() self.classifier = nn.Sequential( - nn.Linear(28*28, 300), + nn.Linear(28 * 28, 300), nn.ReLU(inplace=True), nn.Linear(300, 100), nn.ReLU(inplace=True), diff --git a/examples/model_compress/pruning/model_speedup.py b/examples/model_compress/pruning/model_speedup.py index 48aff8702c..bec053542a 100644 --- a/examples/model_compress/pruning/model_speedup.py +++ b/examples/model_compress/pruning/model_speedup.py @@ -5,8 +5,12 @@ import torch.nn as nn import torch.nn.functional as F from torchvision import datasets, transforms -from models.cifar10.vgg import VGG -from models.mnist.lenet import LeNet + +import sys +sys.path.append('../models') +from cifar10.vgg import VGG +from mnist.lenet import LeNet + from nni.compression.pytorch import apply_compression_results, ModelSpeedup torch.manual_seed(0) diff --git a/examples/model_compress/pruning/naive_prune_torch.py b/examples/model_compress/pruning/naive_prune_torch.py index 5509db9aa5..88ff3df6d9 100644 --- a/examples/model_compress/pruning/naive_prune_torch.py +++ b/examples/model_compress/pruning/naive_prune_torch.py @@ -10,15 +10,16 @@ import argparse import torch -import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms from torch.optim.lr_scheduler import StepLR -from models.mnist.lenet import LeNet + from nni.algorithms.compression.pytorch.pruning import LevelPruner -import nni +import sys +sys.path.append('../models') +from mnist.lenet import LeNet _logger = logging.getLogger('mnist_example') _logger.setLevel(logging.INFO) @@ -108,7 +109,7 @@ def main(args): 'op_types': ['default'], }] - pruner = LevelPruner(model, prune_config, optimizer_finetune) + pruner = LevelPruner(model, prune_config) model = pruner.compress() # fine-tuning @@ -149,5 +150,4 @@ def main(args): help='target overall target sparsity') args = parser.parse_args() - - main(args) \ No newline at end of file + main(args) diff --git a/examples/model_compress/quantization/BNN_quantizer_cifar10.py b/examples/model_compress/quantization/BNN_quantizer_cifar10.py index 1615a289a4..f6d4c27316 100644 --- a/examples/model_compress/quantization/BNN_quantizer_cifar10.py +++ b/examples/model_compress/quantization/BNN_quantizer_cifar10.py @@ -31,7 +31,6 @@ def __init__(self, num_classes=1000): nn.BatchNorm2d(256, eps=1e-4, momentum=0.1), nn.Hardtanh(inplace=True), - nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(512, eps=1e-4, momentum=0.1), nn.Hardtanh(inplace=True), diff --git a/examples/model_compress/quantization/DoReFaQuantizer_torch_mnist.py b/examples/model_compress/quantization/DoReFaQuantizer_torch_mnist.py index 18cd059556..10de852570 100644 --- a/examples/model_compress/quantization/DoReFaQuantizer_torch_mnist.py +++ b/examples/model_compress/quantization/DoReFaQuantizer_torch_mnist.py @@ -3,27 +3,9 @@ from torchvision import datasets, transforms from nni.algorithms.compression.pytorch.quantization import DoReFaQuantizer - -class Mnist(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv1 = torch.nn.Conv2d(1, 20, 5, 1) - self.conv2 = torch.nn.Conv2d(20, 50, 5, 1) - self.fc1 = torch.nn.Linear(4 * 4 * 50, 500) - self.fc2 = torch.nn.Linear(500, 10) - self.relu1 = torch.nn.ReLU6() - self.relu2 = torch.nn.ReLU6() - self.relu3 = torch.nn.ReLU6() - - def forward(self, x): - x = self.relu1(self.conv1(x)) - x = F.max_pool2d(x, 2, 2) - x = self.relu2(self.conv2(x)) - x = F.max_pool2d(x, 2, 2) - x = x.view(-1, 4 * 4 * 50) - x = self.relu3(self.fc1(x)) - x = self.fc2(x) - return F.log_softmax(x, dim=1) +import sys +sys.path.append('../models') +from mnist.naive import NaiveModel def train(model, quantizer, device, train_loader, optimizer): @@ -66,7 +48,7 @@ def main(): datasets.MNIST('data', train=False, transform=trans), batch_size=1000, shuffle=True) - model = Mnist() + model = NaiveModel() model = model.to(device) configure_list = [{ 'quant_types': ['weight'], diff --git a/examples/model_compress/quantization/QAT_torch_quantizer.py b/examples/model_compress/quantization/QAT_torch_quantizer.py index ef14ff5ce0..4ccbe34eb0 100644 --- a/examples/model_compress/quantization/QAT_torch_quantizer.py +++ b/examples/model_compress/quantization/QAT_torch_quantizer.py @@ -3,28 +3,9 @@ from torchvision import datasets, transforms from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer - -class Mnist(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv1 = torch.nn.Conv2d(1, 20, 5, 1) - self.conv2 = torch.nn.Conv2d(20, 50, 5, 1) - self.fc1 = torch.nn.Linear(4 * 4 * 50, 500) - self.fc2 = torch.nn.Linear(500, 10) - self.relu1 = torch.nn.ReLU6() - self.relu2 = torch.nn.ReLU6() - self.relu3 = torch.nn.ReLU6() - - def forward(self, x): - x = self.relu1(self.conv1(x)) - x = F.max_pool2d(x, 2, 2) - x = self.relu2(self.conv2(x)) - x = F.max_pool2d(x, 2, 2) - x = x.view(-1, 4 * 4 * 50) - x = self.relu3(self.fc1(x)) - x = self.fc2(x) - return F.log_softmax(x, dim=1) - +import sys +sys.path.append('../models') +from mnist.naive import NaiveModel def train(model, quantizer, device, train_loader, optimizer): model.train() @@ -66,7 +47,7 @@ def main(): datasets.MNIST('data', train=False, transform=trans), batch_size=1000, shuffle=True) - model = Mnist() + model = NaiveModel() '''you can change this to DoReFaQuantizer to implement it DoReFaQuantizer(configure_list).compress(model) ''' diff --git a/examples/model_compress/quantization/mixed_precision_speedup_mnist.py b/examples/model_compress/quantization/mixed_precision_speedup_mnist.py index bdcdcb7f5f..687fec6a1f 100644 --- a/examples/model_compress/quantization/mixed_precision_speedup_mnist.py +++ b/examples/model_compress/quantization/mixed_precision_speedup_mnist.py @@ -5,28 +5,9 @@ from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT -class Mnist(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv1 = torch.nn.Conv2d(1, 20, 5, 1) - self.conv2 = torch.nn.Conv2d(20, 50, 5, 1) - self.fc1 = torch.nn.Linear(4 * 4 * 50, 500) - self.fc2 = torch.nn.Linear(500, 10) - self.relu1 = torch.nn.ReLU6() - self.relu2 = torch.nn.ReLU6() - self.relu3 = torch.nn.ReLU6() - self.max_pool1 = torch.nn.MaxPool2d(2, 2) - self.max_pool2 = torch.nn.MaxPool2d(2, 2) - - def forward(self, x): - x = self.relu1(self.conv1(x)) - x = self.max_pool1(x) - x = self.relu2(self.conv2(x)) - x = self.max_pool2(x) - x = x.view(-1, 4 * 4 * 50) - x = self.relu3(self.fc1(x)) - x = self.fc2(x) - return F.log_softmax(x, dim=1) +import sys +sys.path.append('../models') +from mnist.naive import NaiveModel def train(model, device, train_loader, optimizer): @@ -74,7 +55,7 @@ def test_trt(engine, test_loader): print("Inference elapsed_time (whole dataset): {}s".format(time_elasped)) def post_training_quantization_example(train_loader, test_loader, device): - model = Mnist() + model = NaiveModel() config = { 'conv1':{'weight_bit':8, 'activation_bit':8}, @@ -99,7 +80,7 @@ def post_training_quantization_example(train_loader, test_loader, device): test_trt(engine, test_loader) def quantization_aware_training_example(train_loader, test_loader, device): - model = Mnist() + model = NaiveModel() configure_list = [{ 'quant_types': ['weight', 'output'], diff --git a/nni/algorithms/compression/pytorch/pruning/__init__.py b/nni/algorithms/compression/pytorch/pruning/__init__.py index f534b25da0..f49cf0cb65 100644 --- a/nni/algorithms/compression/pytorch/pruning/__init__.py +++ b/nni/algorithms/compression/pytorch/pruning/__init__.py @@ -1,14 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from .finegrained_pruning import * -from .structured_pruning import * -from .one_shot import * -from .agp import * +from .finegrained_pruning_masker import * +from .structured_pruning_masker import * +from .one_shot_pruner import * +from .iterative_pruner import * from .lottery_ticket import LotteryTicketPruner from .simulated_annealing_pruner import SimulatedAnnealingPruner from .net_adapt_pruner import NetAdaptPruner -from .admm_pruner import ADMMPruner from .auto_compress_pruner import AutoCompressPruner from .sensitivity_pruner import SensitivityPruner from .amc import AMCPruner diff --git a/nni/algorithms/compression/pytorch/pruning/admm_pruner.py b/nni/algorithms/compression/pytorch/pruning/admm_pruner.py deleted file mode 100644 index 30e73a23f8..0000000000 --- a/nni/algorithms/compression/pytorch/pruning/admm_pruner.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import logging -import torch -from schema import And, Optional -import copy - -from nni.compression.pytorch.utils.config_validation import CompressorSchema -from .constants import MASKER_DICT -from .one_shot import OneshotPruner - - -_logger = logging.getLogger(__name__) - - -class ADMMPruner(OneshotPruner): - """ - A Pytorch implementation of ADMM Pruner algorithm. - - Parameters - ---------- - model : torch.nn.Module - Model to be pruned. - config_list : list - List on pruning configs. - trainer : function - Function used for the first subproblem. - Users should write this function as a normal function to train the Pytorch model - and include `model, optimizer, criterion, epoch, callback` as function arguments. - Here `callback` acts as an L2 regulizer as presented in the formula (7) of the original paper. - The logic of `callback` is implemented inside the Pruner, - users are just required to insert `callback()` between `loss.backward()` and `optimizer.step()`. - Example:: - - def trainer(model, criterion, optimizer, epoch, callback): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - train_loader = ... - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) - optimizer.zero_grad() - output = model(data) - loss = criterion(output, target) - loss.backward() - # callback should be inserted between loss.backward() and optimizer.step() - if callback: - callback() - optimizer.step() - num_iterations : int - Total number of iterations. - training_epochs : int - Training epochs of the first subproblem. - row : float - Penalty parameters for ADMM training. - base_algo : str - Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops, - the assigned `base_algo` is used to decide which filters/channels/weights to prune. - - """ - - def __init__(self, model, config_list, trainer, num_iterations=30, training_epochs=5, row=1e-4, base_algo='l1'): - self._base_algo = base_algo - - super().__init__(model, config_list) - - self._trainer = trainer - self._num_iterations = num_iterations - self._training_epochs = training_epochs - self._row = row - - self.set_wrappers_attribute("if_calculated", False) - self.masker = MASKER_DICT[self._base_algo](self.bound_model, self) - - def validate_config(self, model, config_list): - """ - Parameters - ---------- - model : torch.nn.Module - Model to be pruned - config_list : list - List on pruning configs - """ - - if self._base_algo == 'level': - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), - Optional('op_types'): [str], - Optional('op_names'): [str], - }], model, _logger) - elif self._base_algo in ['l1', 'l2', 'fpgm']: - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), - 'op_types': ['Conv2d'], - Optional('op_names'): [str] - }], model, _logger) - - schema.validate(config_list) - - def _projection(self, weight, sparsity, wrapper): - ''' - Return the Euclidean projection of the weight matrix according to the pruning mode. - - Parameters - ---------- - weight : tensor - original matrix - sparsity : float - the ratio of parameters which need to be set to zero - wrapper: PrunerModuleWrapper - layer wrapper of this layer - - Returns - ------- - tensor - the projected matrix - ''' - wrapper_copy = copy.deepcopy(wrapper) - wrapper_copy.module.weight.data = weight - return weight.data.mul(self.masker.calc_mask(sparsity, wrapper_copy)['weight_mask']) - - def compress(self): - """ - Compress the model with ADMM. - - Returns - ------- - torch.nn.Module - model with specified modules compressed. - """ - _logger.info('Starting ADMM Compression...') - - # initiaze Z, U - # Z_i^0 = W_i^0 - # U_i^0 = 0 - Z = [] - U = [] - for wrapper in self.get_modules_wrapper(): - z = wrapper.module.weight.data - Z.append(z) - U.append(torch.zeros_like(z)) - - optimizer = torch.optim.Adam( - self.bound_model.parameters(), lr=1e-3, weight_decay=5e-5) - - # Loss = cross_entropy + l2 regulization + \Sum_{i=1}^N \row_i ||W_i - Z_i^k + U_i^k||^2 - criterion = torch.nn.CrossEntropyLoss() - - # callback function to do additonal optimization, refer to the deriatives of Formula (7) - def callback(): - for i, wrapper in enumerate(self.get_modules_wrapper()): - wrapper.module.weight.data -= self._row * \ - (wrapper.module.weight.data - Z[i] + U[i]) - - # optimization iteration - for k in range(self._num_iterations): - _logger.info('ADMM iteration : %d', k) - - # step 1: optimize W with AdamOptimizer - for epoch in range(self._training_epochs): - self._trainer(self.bound_model, optimizer=optimizer, - criterion=criterion, epoch=epoch, callback=callback) - - # step 2: update Z, U - # Z_i^{k+1} = projection(W_i^{k+1} + U_i^k) - # U_i^{k+1} = U^k + W_i^{k+1} - Z_i^{k+1} - for i, wrapper in enumerate(self.get_modules_wrapper()): - z = wrapper.module.weight.data + U[i] - Z[i] = self._projection(z, wrapper.config['sparsity'], wrapper) - U[i] = U[i] + wrapper.module.weight.data - Z[i] - - # apply prune - self.update_mask() - - _logger.info('Compression finished.') - - return self.bound_model diff --git a/nni/algorithms/compression/pytorch/pruning/agp.py b/nni/algorithms/compression/pytorch/pruning/agp.py deleted file mode 100644 index ef9ca71635..0000000000 --- a/nni/algorithms/compression/pytorch/pruning/agp.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -An automated gradual pruning algorithm that prunes the smallest magnitude -weights to achieve a preset level of network sparsity. -Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the -efficacy of pruning for model compression", 2017 NIPS Workshop on Machine -Learning of Phones and other Consumer Devices. -""" - -import logging -import torch -from schema import And, Optional -from .constants import MASKER_DICT -from nni.compression.pytorch.utils.config_validation import CompressorSchema -from nni.compression.pytorch.compressor import Pruner - -__all__ = ['AGPPruner'] - -logger = logging.getLogger('torch pruner') - -class AGPPruner(Pruner): - """ - Parameters - ---------- - model : torch.nn.Module - Model to be pruned. - config_list : listlist - Supported keys: - - initial_sparsity: This is to specify the sparsity when compressor starts to compress. - - final_sparsity: This is to specify the sparsity when compressor finishes to compress. - - start_epoch: This is to specify the epoch number when compressor starts to compress, default start from epoch 0. - - end_epoch: This is to specify the epoch number when compressor finishes to compress. - - frequency: This is to specify every *frequency* number epochs compressor compress once, default frequency=1. - optimizer: torch.optim.Optimizer - Optimizer used to train model. - pruning_algorithm: str - Algorithms being used to prune model, - choose from `['level', 'slim', 'l1', 'l2', 'fpgm', 'taylorfo', 'apoz', 'mean_activation']`, by default `level` - """ - - def __init__(self, model, config_list, optimizer, pruning_algorithm='level'): - super().__init__(model, config_list, optimizer) - assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass optimizer of the model to it" - self.masker = MASKER_DICT[pruning_algorithm](model, self) - - self.now_epoch = 0 - self.set_wrappers_attribute("if_calculated", False) - - def validate_config(self, model, config_list): - """ - Parameters - ---------- - model : torch.nn.Module - Model to be pruned - config_list : list - List on pruning configs - """ - schema = CompressorSchema([{ - 'initial_sparsity': And(float, lambda n: 0 <= n <= 1), - 'final_sparsity': And(float, lambda n: 0 <= n <= 1), - 'start_epoch': And(int, lambda n: n >= 0), - 'end_epoch': And(int, lambda n: n >= 0), - 'frequency': And(int, lambda n: n > 0), - Optional('op_types'): [str], - Optional('op_names'): [str] - }], model, logger) - - schema.validate(config_list) - - def calc_mask(self, wrapper, wrapper_idx=None): - """ - Calculate the mask of given layer. - Scale factors with the smallest absolute value in the BN layer are masked. - Parameters - ---------- - wrapper : Module - the layer to instrument the compression operation - wrapper_idx: int - index of this wrapper in pruner's all wrappers - Returns - ------- - dict | None - Dictionary for storing masks, keys of the dict: - 'weight_mask': weight mask tensor - 'bias_mask': bias mask tensor (optional) - """ - - config = wrapper.config - - start_epoch = config.get('start_epoch', 0) - freq = config.get('frequency', 1) - - if wrapper.if_calculated: - return None - if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0): - return None - - target_sparsity = self.compute_target_sparsity(config) - new_mask = self.masker.calc_mask(sparsity=target_sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx) - if new_mask is not None: - wrapper.if_calculated = True - - return new_mask - - def compute_target_sparsity(self, config): - """ - Calculate the sparsity for pruning - Parameters - ---------- - config : dict - Layer's pruning config - Returns - ------- - float - Target sparsity to be pruned - """ - - end_epoch = config.get('end_epoch', 1) - start_epoch = config.get('start_epoch', 0) - freq = config.get('frequency', 1) - final_sparsity = config.get('final_sparsity', 0) - initial_sparsity = config.get('initial_sparsity', 0) - if end_epoch <= start_epoch or initial_sparsity >= final_sparsity: - logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity') - return final_sparsity - - if end_epoch <= self.now_epoch: - return final_sparsity - - span = ((end_epoch - start_epoch - 1) // freq) * freq - assert span > 0 - target_sparsity = (final_sparsity + - (initial_sparsity - final_sparsity) * - (1.0 - ((self.now_epoch - start_epoch) / span)) ** 3) - return target_sparsity - - def update_epoch(self, epoch): - """ - Update epoch - Parameters - ---------- - epoch : int - current training epoch - """ - - if epoch > 0: - self.now_epoch = epoch - for wrapper in self.get_modules_wrapper(): - wrapper.if_calculated = False diff --git a/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py b/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py index fdc27ac2f4..82a8f1cb98 100644 --- a/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py @@ -13,8 +13,7 @@ from nni.compression.pytorch.compressor import Pruner from nni.compression.pytorch.utils.config_validation import CompressorSchema from .simulated_annealing_pruner import SimulatedAnnealingPruner -from .admm_pruner import ADMMPruner - +from .iterative_pruner import ADMMPruner _logger = logging.getLogger(__name__) @@ -34,26 +33,7 @@ class AutoCompressPruner(Pruner): trainer : function Function used for the first subproblem of ADMM Pruner. Users should write this function as a normal function to train the Pytorch model - and include `model, optimizer, criterion, epoch, callback` as function arguments. - Here `callback` acts as an L2 regulizer as presented in the formula (7) of the original paper. - The logic of `callback` is implemented inside the Pruner, - users are just required to insert `callback()` between `loss.backward()` and `optimizer.step()`. - Example:: - - def trainer(model, criterion, optimizer, epoch, callback): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - train_loader = ... - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) - optimizer.zero_grad() - output = model(data) - loss = criterion(output, target) - loss.backward() - # callback should be inserted between loss.backward() and optimizer.step() - if callback: - callback() - optimizer.step() + and include `model, optimizer, criterion, epoch` as function arguments. evaluator : function function to evaluate the pruned model. This function should include `model` as the only parameter, and returns a scalar value. @@ -80,8 +60,8 @@ def evaluator(model): optimize_mode : str optimize mode, `maximize` or `minimize`, by default `maximize`. base_algo : str - Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among the ops, - the assigned `base_algo` is used to decide which filters/channels/weights to prune. + Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among + the ops, the assigned `base_algo` is used to decide which filters/channels/weights to prune. start_temperature : float Start temperature of the simulated annealing process. stop_temperature : float @@ -92,7 +72,7 @@ def evaluator(model): Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature. admm_num_iterations : int Number of iterations of ADMM Pruner. - admm_training_epochs : int + admm_epochs_per_iteration : int Training epochs of the first optimization subproblem of ADMMPruner. row : float Penalty parameters for ADMM training. @@ -100,18 +80,19 @@ def evaluator(model): PATH to store temporary experiment data. """ - def __init__(self, model, config_list, trainer, evaluator, dummy_input, + def __init__(self, model, config_list, trainer, criterion, evaluator, dummy_input, num_iterations=3, optimize_mode='maximize', base_algo='l1', # SimulatedAnnealing related start_temperature=100, stop_temperature=20, cool_down_rate=0.9, perturbation_magnitude=0.35, # ADMM related - admm_num_iterations=30, admm_training_epochs=5, row=1e-4, + admm_num_iterations=30, admm_epochs_per_iteration=5, row=1e-4, experiment_data_dir='./'): # original model self._model_to_prune = model self._base_algo = base_algo self._trainer = trainer + self._criterion = criterion self._evaluator = evaluator self._dummy_input = dummy_input self._num_iterations = num_iterations @@ -125,7 +106,7 @@ def __init__(self, model, config_list, trainer, evaluator, dummy_input, # hyper parameters for ADMM algorithm self._admm_num_iterations = admm_num_iterations - self._admm_training_epochs = admm_training_epochs + self._admm_epochs_per_iteration = admm_epochs_per_iteration self._row = row # overall pruning rate @@ -174,12 +155,12 @@ def compress(self): """ _logger.info('Starting AutoCompress pruning...') - sparsity_each_round = 1 - pow(1-self._sparsity, 1/self._num_iterations) + sparsity_each_round = 1 - pow(1 - self._sparsity, 1 / self._num_iterations) for i in range(self._num_iterations): _logger.info('Pruning iteration: %d', i) _logger.info('Target sparsity this round: %s', - 1-pow(1-sparsity_each_round, i+1)) + 1 - pow(1 - sparsity_each_round, i + 1)) # SimulatedAnnealingPruner _logger.info( @@ -204,9 +185,10 @@ def compress(self): ADMMpruner = ADMMPruner( model=copy.deepcopy(self._model_to_prune), config_list=config_list, + criterion=self._criterion, trainer=self._trainer, num_iterations=self._admm_num_iterations, - training_epochs=self._admm_training_epochs, + epochs_per_iteration=self._admm_epochs_per_iteration, row=self._row, base_algo=self._base_algo) ADMMpruner.compress() @@ -214,12 +196,13 @@ def compress(self): ADMMpruner.export_model(os.path.join(self._experiment_data_dir, 'model_admm_masked.pth'), os.path.join( self._experiment_data_dir, 'mask.pth')) - # use speed up to prune the model before next iteration, because SimulatedAnnealingPruner & ADMMPruner don't take masked models + # use speed up to prune the model before next iteration, + # because SimulatedAnnealingPruner & ADMMPruner don't take masked models self._model_to_prune.load_state_dict(torch.load(os.path.join( self._experiment_data_dir, 'model_admm_masked.pth'))) masks_file = os.path.join(self._experiment_data_dir, 'mask.pth') - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = next(self._model_to_prune.parameters()).device _logger.info('Speeding up models...') m_speedup = ModelSpeedup(self._model_to_prune, self._dummy_input, masks_file, device) diff --git a/nni/algorithms/compression/pytorch/pruning/constants_pruner.py b/nni/algorithms/compression/pytorch/pruning/constants_pruner.py index b0ad5cce37..55ba9506f3 100644 --- a/nni/algorithms/compression/pytorch/pruning/constants_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/constants_pruner.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. -from .one_shot import LevelPruner, L1FilterPruner, L2FilterPruner, FPGMPruner +from .one_shot_pruner import LevelPruner, L1FilterPruner, L2FilterPruner, FPGMPruner PRUNER_DICT = { 'level': LevelPruner, diff --git a/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py b/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py new file mode 100644 index 0000000000..c0ca053a7d --- /dev/null +++ b/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py @@ -0,0 +1,162 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +from schema import And, Optional, SchemaError +from nni.common.graph_utils import TorchModuleGraph +from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency +from nni.compression.pytorch.utils.config_validation import CompressorSchema +from nni.compression.pytorch.compressor import Pruner +from .constants import MASKER_DICT + +__all__ = ['DependencyAwarePruner'] + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class DependencyAwarePruner(Pruner): + """ + DependencyAwarePruner has two ways to calculate the masks + for conv layers. In the normal way, the DependencyAwarePruner + will calculate the mask of each layer separately. For example, each + conv layer determine which filters should be pruned according to its L1 + norm. In constrast, in the dependency-aware way, the layers that in a + dependency group will be pruned jointly and these layers will be forced + to prune the same channels. + """ + + def __init__(self, model, config_list, optimizer=None, pruning_algorithm='level', dependency_aware=False, + dummy_input=None, **algo_kwargs): + super().__init__(model, config_list=config_list, optimizer=optimizer) + + self.dependency_aware = dependency_aware + self.dummy_input = dummy_input + + if self.dependency_aware: + if not self._supported_dependency_aware(): + raise ValueError('This pruner does not support dependency aware!') + + errmsg = "When dependency_aware is set, the dummy_input should not be None" + assert self.dummy_input is not None, errmsg + # Get the TorchModuleGraph of the target model + # to trace the model, we need to unwrap the wrappers + self._unwrap_model() + self.graph = TorchModuleGraph(model, dummy_input) + self._wrap_model() + self.channel_depen = ChannelDependency( + traced_model=self.graph.trace) + self.group_depen = GroupDependency(traced_model=self.graph.trace) + self.channel_depen = self.channel_depen.dependency_sets + self.channel_depen = { + name: sets for sets in self.channel_depen for name in sets} + self.group_depen = self.group_depen.dependency_sets + + self.masker = MASKER_DICT[pruning_algorithm]( + model, self, **algo_kwargs) + # set the dependency-aware switch for the masker + self.masker.dependency_aware = dependency_aware + self.set_wrappers_attribute("if_calculated", False) + + def calc_mask(self, wrapper, wrapper_idx=None): + if not wrapper.if_calculated: + sparsity = wrapper.config['sparsity'] + masks = self.masker.calc_mask( + sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx) + + # masker.calc_mask returns None means calc_mask is not calculated sucessfully, can try later + if masks is not None: + wrapper.if_calculated = True + return masks + else: + return None + + def update_mask(self): + if not self.dependency_aware: + # if we use the normal way to update the mask, + # then call the update_mask of the father class + super(DependencyAwarePruner, self).update_mask() + else: + # if we update the mask in a dependency-aware way + # then we call _dependency_update_mask + self._dependency_update_mask() + + def validate_config(self, model, config_list): + schema = CompressorSchema([{ + Optional('sparsity'): And(float, lambda n: 0 < n < 1), + Optional('op_types'): ['Conv2d'], + Optional('op_names'): [str], + Optional('exclude'): bool + }], model, logger) + + schema.validate(config_list) + for config in config_list: + if 'exclude' not in config and 'sparsity' not in config: + raise SchemaError('Either sparisty or exclude must be specified!') + + def _supported_dependency_aware(self): + raise NotImplementedError + + def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None): + """ + calculate the masks for the conv layers in the same + channel dependecy set. All the layers passed in have + the same number of channels. + + Parameters + ---------- + wrappers: list + The list of the wrappers that in the same channel dependency + set. + wrappers_idx: list + The list of the indexes of wrapppers. + Returns + ------- + masks: dict + A dict object that contains the masks of the layers in this + dependency group, the key is the name of the convolutional layers. + """ + # The number of the groups for each conv layers + # Note that, this number may be different from its + # original number of groups of filters. + groups = [self.group_depen[_w.name] for _w in wrappers] + sparsities = [_w.config['sparsity'] for _w in wrappers] + masks = self.masker.calc_mask( + sparsities, wrappers, wrappers_idx, channel_dsets=channel_dsets, groups=groups) + if masks is not None: + # if masks is None, then the mask calculation fails. + # for example, in activation related maskers, we should + # pass enough batches of data to the model, so that the + # masks can be calculated successfully. + for _w in wrappers: + _w.if_calculated = True + return masks + + def _dependency_update_mask(self): + """ + In the original update_mask, the wraper of each layer will update its + own mask according to the sparsity specified in the config_list. However, in + the _dependency_update_mask, we may prune several layers at the same + time according the sparsities and the channel/group dependencies. + """ + name2wrapper = {x.name: x for x in self.get_modules_wrapper()} + wrapper2index = {x: i for i, x in enumerate(self.get_modules_wrapper())} + for wrapper in self.get_modules_wrapper(): + if wrapper.if_calculated: + continue + # find all the conv layers that have channel dependecy with this layer + # and prune all these layers at the same time. + _names = [x for x in self.channel_depen[wrapper.name]] + logger.info('Pruning the dependent layers: %s', ','.join(_names)) + _wrappers = [name2wrapper[name] + for name in _names if name in name2wrapper] + _wrapper_idxes = [wrapper2index[_w] for _w in _wrappers] + + masks = self._dependency_calc_mask( + _wrappers, _names, wrappers_idx=_wrapper_idxes) + if masks is not None: + for layer in masks: + for mask_type in masks[layer]: + assert hasattr(name2wrapper[layer], mask_type), "there is no attribute '%s' in wrapper on %s" \ + % (mask_type, layer) + setattr(name2wrapper[layer], mask_type, masks[layer][mask_type]) diff --git a/nni/algorithms/compression/pytorch/pruning/finegrained_pruning.py b/nni/algorithms/compression/pytorch/pruning/finegrained_pruning_masker.py similarity index 100% rename from nni/algorithms/compression/pytorch/pruning/finegrained_pruning.py rename to nni/algorithms/compression/pytorch/pruning/finegrained_pruning_masker.py diff --git a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py new file mode 100644 index 0000000000..9651e9e35a --- /dev/null +++ b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py @@ -0,0 +1,576 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +import copy +import torch +from schema import And, Optional +from nni.compression.pytorch.utils.config_validation import CompressorSchema +from .constants import MASKER_DICT +from .dependency_aware_pruner import DependencyAwarePruner + +__all__ = ['AGPPruner', 'ADMMPruner', 'SlimPruner', 'TaylorFOWeightFilterPruner', 'ActivationAPoZRankFilterPruner', + 'ActivationMeanRankFilterPruner'] + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class IterativePruner(DependencyAwarePruner): + """ + Prune model during the training process. + """ + + def __init__(self, model, config_list, optimizer=None, pruning_algorithm='slim', trainer=None, criterion=None, + num_iterations=20, epochs_per_iteration=5, dependency_aware=False, dummy_input=None, **algo_kwargs): + """ + Parameters + ---------- + model: torch.nn.Module + Model to be pruned + config_list: list + List on pruning configs + optimizer: torch.optim.Optimizer + Optimizer used to train model + pruning_algorithm: str + algorithms being used to prune model + trainer: function + Function used to train the model. + Users should write this function as a normal function to train the Pytorch model + and include `model, optimizer, criterion, epoch` as function arguments. + criterion: function + Function used to calculate the loss between the target and the output. + num_iterations: int + Total number of iterations in pruning process. We will calculate mask at the end of an iteration. + epochs_per_iteration: Union[int, list] + The number of training epochs for each iteration. `int` represents the same value for each iteration. + `list` represents the specific value for each iteration. + dependency_aware: bool + If prune the model in a dependency-aware way. + dummy_input: torch.Tensor + The dummy input to analyze the topology constraints. Note that, + the dummy_input should on the same device with the model. + algo_kwargs: dict + Additional parameters passed to pruning algorithm masker class + """ + super().__init__(model, config_list, optimizer, pruning_algorithm, dependency_aware, dummy_input, **algo_kwargs) + + if isinstance(epochs_per_iteration, list): + assert len(epochs_per_iteration) == num_iterations, 'num_iterations should equal to the length of epochs_per_iteration' + self.epochs_per_iteration = epochs_per_iteration + else: + self.epochs_per_iteration = [epochs_per_iteration] * num_iterations + + self._trainer = trainer + self._criterion = criterion + + def _fresh_calculated(self): + for wrapper in self.get_modules_wrapper(): + wrapper.if_calculated = False + + def compress(self): + training = self.bound_model.training + self.bound_model.train() + for _, epochs_num in enumerate(self.epochs_per_iteration): + self._fresh_calculated() + for epoch in range(epochs_num): + self._trainer(self.bound_model, optimizer=self.optimizer, criterion=self._criterion, epoch=epoch) + self.update_mask() + self.bound_model.train(training) + + return self.bound_model + + +class AGPPruner(IterativePruner): + """ + Parameters + ---------- + model : torch.nn.Module + Model to be pruned. + config_list : listlist + Supported keys: + - sparsity : This is to specify the sparsity operations to be compressed to. + - op_types : See supported type in your specific pruning algorithm. + optimizer: torch.optim.Optimizer + Optimizer used to train model. + trainer: function + Function to train the model + criterion: function + Function used to calculate the loss between the target and the output. + num_iterations: int + Total number of iterations in pruning process. We will calculate mask at the end of an iteration. + epochs_per_iteration: int + The number of training epochs for each iteration. + pruning_algorithm: str + Algorithms being used to prune model, + choose from `['level', 'slim', 'l1', 'l2', 'fpgm', 'taylorfo', 'apoz', 'mean_activation']`, by default `level` + """ + + def __init__(self, model, config_list, optimizer, trainer, criterion, num_iterations=10, epochs_per_iteration=1, pruning_algorithm='level'): + super().__init__(model, config_list, optimizer=optimizer, trainer=trainer, criterion=criterion, + num_iterations=num_iterations, epochs_per_iteration=epochs_per_iteration) + assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass optimizer of the model to it" + self.masker = MASKER_DICT[pruning_algorithm](model, self) + self.now_epoch = 0 + self.freq = epochs_per_iteration + self.end_epoch = epochs_per_iteration * num_iterations + self.set_wrappers_attribute("if_calculated", False) + + def validate_config(self, model, config_list): + """ + Parameters + ---------- + model : torch.nn.Module + Model to be pruned + config_list : list + List on pruning configs + """ + schema = CompressorSchema([{ + 'sparsity': And(float, lambda n: 0 <= n <= 1), + Optional('op_types'): [str], + Optional('op_names'): [str] + }], model, logger) + + schema.validate(config_list) + + def _supported_dependency_aware(self): + return False + + def calc_mask(self, wrapper, wrapper_idx=None): + """ + Calculate the mask of given layer. + Scale factors with the smallest absolute value in the BN layer are masked. + Parameters + ---------- + wrapper : Module + the layer to instrument the compression operation + wrapper_idx: int + index of this wrapper in pruner's all wrappers + Returns + ------- + dict | None + Dictionary for storing masks, keys of the dict: + 'weight_mask': weight mask tensor + 'bias_mask': bias mask tensor (optional) + """ + + config = wrapper.config + + if wrapper.if_calculated: + return None + + if not self.now_epoch % self.freq == 0: + return None + + target_sparsity = self.compute_target_sparsity(config) + new_mask = self.masker.calc_mask(sparsity=target_sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx) + + if new_mask is not None: + wrapper.if_calculated = True + + return new_mask + + def compute_target_sparsity(self, config): + """ + Calculate the sparsity for pruning + Parameters + ---------- + config : dict + Layer's pruning config + Returns + ------- + float + Target sparsity to be pruned + """ + + initial_sparsity = 0 + self.target_sparsity = final_sparsity = config.get('sparsity', 0) + + if initial_sparsity >= final_sparsity: + logger.warning('your initial_sparsity >= final_sparsity') + return final_sparsity + + if self.end_epoch == 1 or self.end_epoch <= self.now_epoch: + return final_sparsity + + span = ((self.end_epoch - 1) // self.freq) * self.freq + assert span > 0 + self.target_sparsity = (final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - (self.now_epoch / span)) ** 3) + return self.target_sparsity + + def update_epoch(self, epoch): + """ + Update epoch + Parameters + ---------- + epoch : int + current training epoch + """ + + if epoch > 0: + self.now_epoch = epoch + for wrapper in self.get_modules_wrapper(): + wrapper.if_calculated = False + + # TODO: need refactor + def compress(self): + training = self.bound_model.training + self.bound_model.train() + + for epoch in range(self.end_epoch): + self.update_epoch(epoch) + self._trainer(self.bound_model, optimizer=self.optimizer, criterion=self._criterion, epoch=epoch) + self.update_mask() + logger.info(f'sparsity is {self.target_sparsity:.2f} at epoch {epoch}') + self.get_pruned_weights() + + self.bound_model.train(training) + + return self.bound_model + + +class ADMMPruner(IterativePruner): + """ + A Pytorch implementation of ADMM Pruner algorithm. + + Parameters + ---------- + model : torch.nn.Module + Model to be pruned. + config_list : list + List on pruning configs. + trainer : function + Function used for the first subproblem. + Users should write this function as a normal function to train the Pytorch model + and include `model, optimizer, criterion, epoch` as function arguments. + criterion: function + Function used to calculate the loss between the target and the output. By default, we use CrossEntropyLoss in ADMMPruner. + num_iterations: int + Total number of iterations in pruning process. We will calculate mask after we finish all iterations in ADMMPruner. + epochs_per_iteration: int + Training epochs of the first subproblem. + row : float + Penalty parameters for ADMM training. + base_algo : str + Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among + the ops, the assigned `base_algo` is used to decide which filters/channels/weights to prune. + + """ + + def __init__(self, model, config_list, trainer, criterion=torch.nn.CrossEntropyLoss(), + num_iterations=30, epochs_per_iteration=5, row=1e-4, base_algo='l1'): + self._base_algo = base_algo + + super().__init__(model, config_list) + + self._trainer = trainer + self.optimizer = torch.optim.Adam( + self.bound_model.parameters(), lr=1e-3, weight_decay=5e-5) + self._criterion = criterion + self._num_iterations = num_iterations + self._training_epochs = epochs_per_iteration + self._row = row + + self.set_wrappers_attribute("if_calculated", False) + self.masker = MASKER_DICT[self._base_algo](self.bound_model, self) + + self.patch_optimizer_before(self._callback) + + def validate_config(self, model, config_list): + """ + Parameters + ---------- + model : torch.nn.Module + Model to be pruned + config_list : list + List on pruning configs + """ + + if self._base_algo == 'level': + schema = CompressorSchema([{ + 'sparsity': And(float, lambda n: 0 < n < 1), + Optional('op_types'): [str], + Optional('op_names'): [str], + }], model, logger) + elif self._base_algo in ['l1', 'l2', 'fpgm']: + schema = CompressorSchema([{ + 'sparsity': And(float, lambda n: 0 < n < 1), + 'op_types': ['Conv2d'], + Optional('op_names'): [str] + }], model, logger) + + schema.validate(config_list) + + def _supported_dependency_aware(self): + return False + + def _projection(self, weight, sparsity, wrapper): + ''' + Return the Euclidean projection of the weight matrix according to the pruning mode. + + Parameters + ---------- + weight : tensor + original matrix + sparsity : float + the ratio of parameters which need to be set to zero + wrapper: PrunerModuleWrapper + layer wrapper of this layer + + Returns + ------- + tensor + the projected matrix + ''' + wrapper_copy = copy.deepcopy(wrapper) + wrapper_copy.module.weight.data = weight + return weight.data.mul(self.masker.calc_mask(sparsity, wrapper_copy)['weight_mask']) + + def _callback(self): + # callback function to do additonal optimization, refer to the deriatives of Formula (7) + for i, wrapper in enumerate(self.get_modules_wrapper()): + wrapper.module.weight.data -= self._row * \ + (wrapper.module.weight.data - self.Z[i] + self.U[i]) + + def compress(self): + """ + Compress the model with ADMM. + + Returns + ------- + torch.nn.Module + model with specified modules compressed. + """ + logger.info('Starting ADMM Compression...') + + # initiaze Z, U + # Z_i^0 = W_i^0 + # U_i^0 = 0 + self.Z = [] + self.U = [] + for wrapper in self.get_modules_wrapper(): + z = wrapper.module.weight.data + self.Z.append(z) + self.U.append(torch.zeros_like(z)) + + # Loss = cross_entropy + l2 regulization + \Sum_{i=1}^N \row_i ||W_i - Z_i^k + U_i^k||^2 + # optimization iteration + for k in range(self._num_iterations): + logger.info('ADMM iteration : %d', k) + + # step 1: optimize W with AdamOptimizer + for epoch in range(self._training_epochs): + self._trainer(self.bound_model, optimizer=self.optimizer, criterion=self._criterion, epoch=epoch) + + # step 2: update Z, U + # Z_i^{k+1} = projection(W_i^{k+1} + U_i^k) + # U_i^{k+1} = U^k + W_i^{k+1} - Z_i^{k+1} + for i, wrapper in enumerate(self.get_modules_wrapper()): + z = wrapper.module.weight.data + self.U[i] + self.Z[i] = self._projection(z, wrapper.config['sparsity'], wrapper) + self.U[i] = self.U[i] + wrapper.module.weight.data - self.Z[i] + + # apply prune + self.update_mask() + + logger.info('Compression finished.') + + return self.bound_model + + +class SlimPruner(IterativePruner): + """ + Parameters + ---------- + model : torch.nn.Module + Model to be pruned + config_list : list + Supported keys: + - sparsity : This is to specify the sparsity operations to be compressed to. + - op_types : Only BatchNorm2d is supported in Slim Pruner. + optimizer : torch.optim.Optimizer + Optimizer used to train model + trainer : function + Function used to sparsify BatchNorm2d scaling factors. + Users should write this function as a normal function to train the Pytorch model + and include `model, optimizer, criterion, epoch` as function arguments. + criterion : function + Function used to calculate the loss between the target and the output. + sparsity_training_epochs: int + The number of channel sparsity regularization training epochs before pruning. + scale : float + Penalty parameters for sparsification. + dependency_aware: bool + If prune the model in a dependency-aware way. If it is `True`, this pruner will + prune the model according to the l2-norm of weights and the channel-dependency or + group-dependency of the model. In this way, the pruner will force the conv layers + that have dependencies to prune the same channels, so the speedup module can better + harvest the speed benefit from the pruned model. Note that, if this flag is set True + , the dummy_input cannot be None, because the pruner needs a dummy input to trace the + dependency between the conv layers. + dummy_input : torch.Tensor + The dummy input to analyze the topology constraints. Note that, the dummy_input + should on the same device with the model. + """ + + def __init__(self, model, config_list, optimizer, trainer, criterion, sparsity_training_epochs=10, scale=0.0001, + dependency_aware=False, dummy_input=None): + super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='slim', trainer=trainer, criterion=criterion, + num_iterations=1, epochs_per_iteration=sparsity_training_epochs, dependency_aware=dependency_aware, + dummy_input=dummy_input) + self.scale = scale + self.patch_optimizer_before(self._callback) + + def validate_config(self, model, config_list): + schema = CompressorSchema([{ + 'sparsity': And(float, lambda n: 0 < n < 1), + 'op_types': ['BatchNorm2d'], + Optional('op_names'): [str] + }], model, logger) + + schema.validate(config_list) + + if len(config_list) > 1: + logger.warning('Slim pruner only supports 1 configuration') + + def _supported_dependency_aware(self): + return True + + def _callback(self): + for _, wrapper in enumerate(self.get_modules_wrapper()): + wrapper.module.weight.grad.data.add_(self.scale * torch.sign(wrapper.module.weight.data)) + + +class TaylorFOWeightFilterPruner(IterativePruner): + """ + Parameters + ---------- + model : torch.nn.Module + Model to be pruned + config_list : list + Supported keys: + - sparsity : How much percentage of convolutional filters are to be pruned. + - op_types : Currently only Conv2d is supported in TaylorFOWeightFilterPruner. + optimizer: torch.optim.Optimizer + Optimizer used to train model + trainer : function + Function used to sparsify BatchNorm2d scaling factors. + Users should write this function as a normal function to train the Pytorch model + and include `model, optimizer, criterion, epoch` as function arguments. + criterion : function + Function used to calculate the loss between the target and the output. + sparsity_training_epochs: int + The number of epochs to collect the contributions. + dependency_aware: bool + If prune the model in a dependency-aware way. If it is `True`, this pruner will + prune the model according to the l2-norm of weights and the channel-dependency or + group-dependency of the model. In this way, the pruner will force the conv layers + that have dependencies to prune the same channels, so the speedup module can better + harvest the speed benefit from the pruned model. Note that, if this flag is set True + , the dummy_input cannot be None, because the pruner needs a dummy input to trace the + dependency between the conv layers. + dummy_input : torch.Tensor + The dummy input to analyze the topology constraints. Note that, the dummy_input + should on the same device with the model. + + """ + + def __init__(self, model, config_list, optimizer, trainer, criterion, sparsity_training_epochs=1, dependency_aware=False, + dummy_input=None): + super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='taylorfo', trainer=trainer, + criterion=criterion, num_iterations=1, epochs_per_iteration=sparsity_training_epochs, + dependency_aware=dependency_aware, dummy_input=dummy_input) + + def _supported_dependency_aware(self): + return True + + +class ActivationAPoZRankFilterPruner(IterativePruner): + """ + Parameters + ---------- + model : torch.nn.Module + Model to be pruned + config_list : list + Supported keys: + - sparsity : How much percentage of convolutional filters are to be pruned. + - op_types : Only Conv2d is supported in ActivationAPoZRankFilterPruner. + optimizer: torch.optim.Optimizer + Optimizer used to train model + trainer: function + Function used to train the model. + Users should write this function as a normal function to train the Pytorch model + and include `model, optimizer, criterion, epoch` as function arguments. + criterion : function + Function used to calculate the loss between the target and the output. + activation: str + The activation type. + sparsity_training_epochs: int + The number of epochs to statistic the activation. + dependency_aware: bool + If prune the model in a dependency-aware way. If it is `True`, this pruner will + prune the model according to the l2-norm of weights and the channel-dependency or + group-dependency of the model. In this way, the pruner will force the conv layers + that have dependencies to prune the same channels, so the speedup module can better + harvest the speed benefit from the pruned model. Note that, if this flag is set True + , the dummy_input cannot be None, because the pruner needs a dummy input to trace the + dependency between the conv layers. + dummy_input : torch.Tensor + The dummy input to analyze the topology constraints. Note that, the dummy_input + should on the same device with the model. + + """ + + def __init__(self, model, config_list, optimizer, trainer, criterion, activation='relu', + sparsity_training_epochs=1, dependency_aware=False, dummy_input=None): + super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, trainer=trainer, + criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input, + activation=activation, num_iterations=1, epochs_per_iteration=sparsity_training_epochs) + + def _supported_dependency_aware(self): + return True + + +class ActivationMeanRankFilterPruner(IterativePruner): + """ + Parameters + ---------- + model : torch.nn.Module + Model to be pruned + config_list : list + Supported keys: + - sparsity : How much percentage of convolutional filters are to be pruned. + - op_types : Only Conv2d is supported in ActivationMeanRankFilterPruner. + optimizer: torch.optim.Optimizer + Optimizer used to train model. + trainer: function + Function used to train the model. + Users should write this function as a normal function to train the Pytorch model + and include `model, optimizer, criterion, epoch` as function arguments. + criterion : function + Function used to calculate the loss between the target and the output. + activation: str + The activation type. + sparsity_training_epochs: int + The number of batches to statistic the activation. + dependency_aware: bool + If prune the model in a dependency-aware way. If it is `True`, this pruner will + prune the model according to the l2-norm of weights and the channel-dependency or + group-dependency of the model. In this way, the pruner will force the conv layers + that have dependencies to prune the same channels, so the speedup module can better + harvest the speed benefit from the pruned model. Note that, if this flag is set True + , the dummy_input cannot be None, because the pruner needs a dummy input to trace the + dependency between the conv layers. + dummy_input : torch.Tensor + The dummy input to analyze the topology constraints. Note that, the dummy_input + should on the same device with the model. + """ + + def __init__(self, model, config_list, optimizer, trainer, criterion, activation='relu', + sparsity_training_epochs=1, dependency_aware=False, dummy_input=None): + super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer, trainer=trainer, + criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input, + activation=activation, num_iterations=1, epochs_per_iteration=sparsity_training_epochs) + + def _supported_dependency_aware(self): + return True diff --git a/nni/algorithms/compression/pytorch/pruning/lottery_ticket.py b/nni/algorithms/compression/pytorch/pruning/lottery_ticket.py index b0d041dd02..caa1c831e6 100644 --- a/nni/algorithms/compression/pytorch/pruning/lottery_ticket.py +++ b/nni/algorithms/compression/pytorch/pruning/lottery_ticket.py @@ -7,7 +7,7 @@ from schema import And, Optional from nni.compression.pytorch.utils.config_validation import CompressorSchema from nni.compression.pytorch.compressor import Pruner -from .finegrained_pruning import LevelPrunerMasker +from .finegrained_pruning_masker import LevelPrunerMasker logger = logging.getLogger('torch pruner') diff --git a/nni/algorithms/compression/pytorch/pruning/one_shot.py b/nni/algorithms/compression/pytorch/pruning/one_shot.py deleted file mode 100644 index 75e2a7c307..0000000000 --- a/nni/algorithms/compression/pytorch/pruning/one_shot.py +++ /dev/null @@ -1,460 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import logging -from schema import And, Optional, SchemaError -from nni.common.graph_utils import TorchModuleGraph -from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency -from .constants import MASKER_DICT -from nni.compression.pytorch.utils.config_validation import CompressorSchema -from nni.compression.pytorch.compressor import Pruner - - -__all__ = ['LevelPruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner', - 'TaylorFOWeightFilterPruner', 'ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner'] - -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -class OneshotPruner(Pruner): - """ - Prune model to an exact pruning level for one time. - """ - - def __init__(self, model, config_list, pruning_algorithm='level', optimizer=None, **algo_kwargs): - """ - Parameters - ---------- - model : torch.nn.Module - Model to be pruned - config_list : list - List on pruning configs - pruning_algorithm: str - algorithms being used to prune model - optimizer: torch.optim.Optimizer - Optimizer used to train model - algo_kwargs: dict - Additional parameters passed to pruning algorithm masker class - """ - - super().__init__(model, config_list, optimizer) - self.set_wrappers_attribute("if_calculated", False) - self.masker = MASKER_DICT[pruning_algorithm]( - model, self, **algo_kwargs) - - def validate_config(self, model, config_list): - """ - Parameters - ---------- - model : torch.nn.Module - Model to be pruned - config_list : list - List on pruning configs - """ - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), - Optional('op_types'): [str], - Optional('op_names'): [str] - }], model, logger) - - schema.validate(config_list) - - def calc_mask(self, wrapper, wrapper_idx=None): - """ - Calculate the mask of given layer - Parameters - ---------- - wrapper : Module - the module to instrument the compression operation - wrapper_idx: int - index of this wrapper in pruner's all wrappers - Returns - ------- - dict - dictionary for storing masks, keys of the dict: - 'weight_mask': weight mask tensor - 'bias_mask': bias mask tensor (optional) - """ - if wrapper.if_calculated: - return None - - sparsity = wrapper.config['sparsity'] - if not wrapper.if_calculated: - masks = self.masker.calc_mask( - sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx) - - # masker.calc_mask returns None means calc_mask is not calculated sucessfully, can try later - if masks is not None: - wrapper.if_calculated = True - return masks - else: - return None - - -class LevelPruner(OneshotPruner): - """ - Parameters - ---------- - model : torch.nn.Module - Model to be pruned - config_list : list - Supported keys: - - sparsity : This is to specify the sparsity operations to be compressed to. - - op_types : Operation types to prune. - optimizer: torch.optim.Optimizer - Optimizer used to train model - """ - - def __init__(self, model, config_list, optimizer=None): - super().__init__(model, config_list, pruning_algorithm='level', optimizer=optimizer) - - -class SlimPruner(OneshotPruner): - """ - Parameters - ---------- - model : torch.nn.Module - Model to be pruned - config_list : list - Supported keys: - - sparsity : This is to specify the sparsity operations to be compressed to. - - op_types : Only BatchNorm2d is supported in Slim Pruner. - optimizer: torch.optim.Optimizer - Optimizer used to train model - """ - - def __init__(self, model, config_list, optimizer=None): - super().__init__(model, config_list, pruning_algorithm='slim', optimizer=optimizer) - - def validate_config(self, model, config_list): - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), - 'op_types': ['BatchNorm2d'], - Optional('op_names'): [str] - }], model, logger) - - schema.validate(config_list) - - if len(config_list) > 1: - logger.warning('Slim pruner only supports 1 configuration') - - -class _StructuredFilterPruner(OneshotPruner): - """ - _StructuredFilterPruner has two ways to calculate the masks - for conv layers. In the normal way, the _StructuredFilterPruner - will calculate the mask of each layer separately. For example, each - conv layer determine which filters should be pruned according to its L1 - norm. In constrast, in the dependency-aware way, the layers that in a - dependency group will be pruned jointly and these layers will be forced - to prune the same channels. - """ - - def __init__(self, model, config_list, pruning_algorithm, optimizer=None, dependency_aware=False, dummy_input=None, **algo_kwargs): - super().__init__(model, config_list, pruning_algorithm=pruning_algorithm, - optimizer=optimizer, **algo_kwargs) - self.dependency_aware = dependency_aware - # set the dependency-aware switch for the masker - self.masker.dependency_aware = dependency_aware - self.dummy_input = dummy_input - if self.dependency_aware: - errmsg = "When dependency_aware is set, the dummy_input should not be None" - assert self.dummy_input is not None, errmsg - # Get the TorchModuleGraph of the target model - # to trace the model, we need to unwrap the wrappers - self._unwrap_model() - self.graph = TorchModuleGraph(model, dummy_input) - self._wrap_model() - self.channel_depen = ChannelDependency( - traced_model=self.graph.trace) - self.group_depen = GroupDependency(traced_model=self.graph.trace) - self.channel_depen = self.channel_depen.dependency_sets - self.channel_depen = { - name: sets for sets in self.channel_depen for name in sets} - self.group_depen = self.group_depen.dependency_sets - - def update_mask(self): - if not self.dependency_aware: - # if we use the normal way to update the mask, - # then call the update_mask of the father class - super(_StructuredFilterPruner, self).update_mask() - else: - # if we update the mask in a dependency-aware way - # then we call _dependency_update_mask - self._dependency_update_mask() - - def validate_config(self, model, config_list): - schema = CompressorSchema([{ - Optional('sparsity'): And(float, lambda n: 0 < n < 1), - Optional('op_types'): ['Conv2d'], - Optional('op_names'): [str], - Optional('exclude'): bool - }], model, logger) - - schema.validate(config_list) - for config in config_list: - if 'exclude' not in config and 'sparsity' not in config: - raise SchemaError('Either sparisty or exclude must be specified!') - - def _dependency_calc_mask(self, wrappers, channel_dsets, wrappers_idx=None): - """ - calculate the masks for the conv layers in the same - channel dependecy set. All the layers passed in have - the same number of channels. - - Parameters - ---------- - wrappers: list - The list of the wrappers that in the same channel dependency - set. - wrappers_idx: list - The list of the indexes of wrapppers. - Returns - ------- - masks: dict - A dict object that contains the masks of the layers in this - dependency group, the key is the name of the convolutional layers. - """ - # The number of the groups for each conv layers - # Note that, this number may be different from its - # original number of groups of filters. - groups = [self.group_depen[_w.name] for _w in wrappers] - sparsities = [_w.config['sparsity'] for _w in wrappers] - masks = self.masker.calc_mask( - sparsities, wrappers, wrappers_idx, channel_dsets=channel_dsets, groups=groups) - if masks is not None: - # if masks is None, then the mask calculation fails. - # for example, in activation related maskers, we should - # pass enough batches of data to the model, so that the - # masks can be calculated successfully. - for _w in wrappers: - _w.if_calculated = True - return masks - - def _dependency_update_mask(self): - """ - In the original update_mask, the wraper of each layer will update its - own mask according to the sparsity specified in the config_list. However, in - the _dependency_update_mask, we may prune several layers at the same - time according the sparsities and the channel/group dependencies. - """ - name2wrapper = {x.name: x for x in self.get_modules_wrapper()} - wrapper2index = {x: i for i, x in enumerate(self.get_modules_wrapper())} - for wrapper in self.get_modules_wrapper(): - if wrapper.if_calculated: - continue - # find all the conv layers that have channel dependecy with this layer - # and prune all these layers at the same time. - _names = [x for x in self.channel_depen[wrapper.name]] - logger.info('Pruning the dependent layers: %s', ','.join(_names)) - _wrappers = [name2wrapper[name] - for name in _names if name in name2wrapper] - _wrapper_idxes = [wrapper2index[_w] for _w in _wrappers] - - masks = self._dependency_calc_mask( - _wrappers, _names, wrappers_idx=_wrapper_idxes) - if masks is not None: - for layer in masks: - for mask_type in masks[layer]: - assert hasattr( - name2wrapper[layer], mask_type), "there is no attribute '%s' in wrapper on %s" % (mask_type, layer) - setattr(name2wrapper[layer], mask_type, masks[layer][mask_type]) - - -class L1FilterPruner(_StructuredFilterPruner): - """ - Parameters - ---------- - model : torch.nn.Module - Model to be pruned - config_list : list - Supported keys: - - sparsity : This is to specify the sparsity operations to be compressed to. - - op_types : Only Conv2d is supported in L1FilterPruner. - optimizer: torch.optim.Optimizer - Optimizer used to train model - dependency_aware: bool - If prune the model in a dependency-aware way. If it is `True`, this pruner will - prune the model according to the l2-norm of weights and the channel-dependency or - group-dependency of the model. In this way, the pruner will force the conv layers - that have dependencies to prune the same channels, so the speedup module can better - harvest the speed benefit from the pruned model. Note that, if this flag is set True - , the dummy_input cannot be None, because the pruner needs a dummy input to trace the - dependency between the conv layers. - dummy_input : torch.Tensor - The dummy input to analyze the topology constraints. Note that, the dummy_input - should on the same device with the model. - """ - - def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None): - super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer, - dependency_aware=dependency_aware, dummy_input=dummy_input) - - -class L2FilterPruner(_StructuredFilterPruner): - """ - Parameters - ---------- - model : torch.nn.Module - Model to be pruned - config_list : list - Supported keys: - - sparsity : This is to specify the sparsity operations to be compressed to. - - op_types : Only Conv2d is supported in L2FilterPruner. - optimizer: torch.optim.Optimizer - Optimizer used to train model - dependency_aware: bool - If prune the model in a dependency-aware way. If it is `True`, this pruner will - prune the model according to the l2-norm of weights and the channel-dependency or - group-dependency of the model. In this way, the pruner will force the conv layers - that have dependencies to prune the same channels, so the speedup module can better - harvest the speed benefit from the pruned model. Note that, if this flag is set True - , the dummy_input cannot be None, because the pruner needs a dummy input to trace the - dependency between the conv layers. - dummy_input : torch.Tensor - The dummy input to analyze the topology constraints. Note that, the dummy_input - should on the same device with the model. - """ - - def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None): - super().__init__(model, config_list, pruning_algorithm='l2', optimizer=optimizer, - dependency_aware=dependency_aware, dummy_input=dummy_input) - - -class FPGMPruner(_StructuredFilterPruner): - """ - Parameters - ---------- - model : torch.nn.Module - Model to be pruned - config_list : list - Supported keys: - - sparsity : This is to specify the sparsity operations to be compressed to. - - op_types : Only Conv2d is supported in FPGM Pruner. - optimizer: torch.optim.Optimizer - Optimizer used to train model - dependency_aware: bool - If prune the model in a dependency-aware way. If it is `True`, this pruner will - prune the model according to the l2-norm of weights and the channel-dependency or - group-dependency of the model. In this way, the pruner will force the conv layers - that have dependencies to prune the same channels, so the speedup module can better - harvest the speed benefit from the pruned model. Note that, if this flag is set True - , the dummy_input cannot be None, because the pruner needs a dummy input to trace the - dependency between the conv layers. - dummy_input : torch.Tensor - The dummy input to analyze the topology constraints. Note that, the dummy_input - should on the same device with the model. - """ - - def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None): - super().__init__(model, config_list, pruning_algorithm='fpgm', - dependency_aware=dependency_aware, dummy_input=dummy_input, optimizer=optimizer) - - -class TaylorFOWeightFilterPruner(_StructuredFilterPruner): - """ - Parameters - ---------- - model : torch.nn.Module - Model to be pruned - config_list : list - Supported keys: - - sparsity : How much percentage of convolutional filters are to be pruned. - - op_types : Currently only Conv2d is supported in TaylorFOWeightFilterPruner. - optimizer: torch.optim.Optimizer - Optimizer used to train model - statistics_batch_num: int - The number of batches to statistic the activation. - dependency_aware: bool - If prune the model in a dependency-aware way. If it is `True`, this pruner will - prune the model according to the l2-norm of weights and the channel-dependency or - group-dependency of the model. In this way, the pruner will force the conv layers - that have dependencies to prune the same channels, so the speedup module can better - harvest the speed benefit from the pruned model. Note that, if this flag is set True - , the dummy_input cannot be None, because the pruner needs a dummy input to trace the - dependency between the conv layers. - dummy_input : torch.Tensor - The dummy input to analyze the topology constraints. Note that, the dummy_input - should on the same device with the model. - - """ - - def __init__(self, model, config_list, optimizer=None, statistics_batch_num=1, - dependency_aware=False, dummy_input=None): - super().__init__(model, config_list, pruning_algorithm='taylorfo', - dependency_aware=dependency_aware, dummy_input=dummy_input, - optimizer=optimizer, statistics_batch_num=statistics_batch_num) - - -class ActivationAPoZRankFilterPruner(_StructuredFilterPruner): - """ - Parameters - ---------- - model : torch.nn.Module - Model to be pruned - config_list : list - Supported keys: - - sparsity : How much percentage of convolutional filters are to be pruned. - - op_types : Only Conv2d is supported in ActivationAPoZRankFilterPruner. - optimizer: torch.optim.Optimizer - Optimizer used to train model - activation: str - The activation type. - statistics_batch_num: int - The number of batches to statistic the activation. - dependency_aware: bool - If prune the model in a dependency-aware way. If it is `True`, this pruner will - prune the model according to the l2-norm of weights and the channel-dependency or - group-dependency of the model. In this way, the pruner will force the conv layers - that have dependencies to prune the same channels, so the speedup module can better - harvest the speed benefit from the pruned model. Note that, if this flag is set True - , the dummy_input cannot be None, because the pruner needs a dummy input to trace the - dependency between the conv layers. - dummy_input : torch.Tensor - The dummy input to analyze the topology constraints. Note that, the dummy_input - should on the same device with the model. - - """ - - def __init__(self, model, config_list, optimizer=None, activation='relu', - statistics_batch_num=1, dependency_aware=False, dummy_input=None): - super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, - dependency_aware=dependency_aware, dummy_input=dummy_input, - activation=activation, statistics_batch_num=statistics_batch_num) - - -class ActivationMeanRankFilterPruner(_StructuredFilterPruner): - """ - Parameters - ---------- - model : torch.nn.Module - Model to be pruned - config_list : list - Supported keys: - - sparsity : How much percentage of convolutional filters are to be pruned. - - op_types : Only Conv2d is supported in ActivationMeanRankFilterPruner. - optimizer: torch.optim.Optimizer - Optimizer used to train model. - activation: str - The activation type. - statistics_batch_num: int - The number of batches to statistic the activation. - dependency_aware: bool - If prune the model in a dependency-aware way. If it is `True`, this pruner will - prune the model according to the l2-norm of weights and the channel-dependency or - group-dependency of the model. In this way, the pruner will force the conv layers - that have dependencies to prune the same channels, so the speedup module can better - harvest the speed benefit from the pruned model. Note that, if this flag is set True - , the dummy_input cannot be None, because the pruner needs a dummy input to trace the - dependency between the conv layers. - dummy_input : torch.Tensor - The dummy input to analyze the topology constraints. Note that, the dummy_input - should on the same device with the model. - """ - - def __init__(self, model, config_list, optimizer=None, activation='relu', - statistics_batch_num=1, dependency_aware=False, dummy_input=None): - super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer, - dependency_aware=dependency_aware, dummy_input=dummy_input, - activation=activation, statistics_batch_num=statistics_batch_num) diff --git a/nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py b/nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py new file mode 100644 index 0000000000..c17a5ddafa --- /dev/null +++ b/nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py @@ -0,0 +1,169 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +from schema import And, Optional + +from nni.compression.pytorch.utils.config_validation import CompressorSchema +from .dependency_aware_pruner import DependencyAwarePruner + +__all__ = ['LevelPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner'] + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class OneshotPruner(DependencyAwarePruner): + """ + Prune model to an exact pruning level for one time. + """ + + def __init__(self, model, config_list, pruning_algorithm='level', dependency_aware=False, dummy_input=None, + **algo_kwargs): + """ + Parameters + ---------- + model : torch.nn.Module + Model to be pruned + config_list : list + List on pruning configs + pruning_algorithm: str + algorithms being used to prune model + dependency_aware: bool + If prune the model in a dependency-aware way. + dummy_input : torch.Tensor + The dummy input to analyze the topology constraints. Note that, + the dummy_input should on the same device with the model. + algo_kwargs: dict + Additional parameters passed to pruning algorithm masker class + """ + super().__init__(model, config_list, None, pruning_algorithm, dependency_aware, dummy_input, **algo_kwargs) + + def validate_config(self, model, config_list): + """ + Parameters + ---------- + model : torch.nn.Module + Model to be pruned + config_list : list + List on pruning configs + """ + schema = CompressorSchema([{ + 'sparsity': And(float, lambda n: 0 < n < 1), + Optional('op_types'): [str], + Optional('op_names'): [str] + }], model, logger) + + schema.validate(config_list) + + +class LevelPruner(OneshotPruner): + """ + Parameters + ---------- + model : torch.nn.Module + Model to be pruned + config_list : list + Supported keys: + - sparsity : This is to specify the sparsity operations to be compressed to. + - op_types : Operation types to prune. + """ + + def __init__(self, model, config_list): + super().__init__(model, config_list, pruning_algorithm='level') + + def _supported_dependency_aware(self): + return False + + +class L1FilterPruner(OneshotPruner): + """ + Parameters + ---------- + model : torch.nn.Module + Model to be pruned + config_list : list + Supported keys: + - sparsity : This is to specify the sparsity operations to be compressed to. + - op_types : Only Conv2d is supported in L1FilterPruner. + dependency_aware: bool + If prune the model in a dependency-aware way. If it is `True`, this pruner will + prune the model according to the l2-norm of weights and the channel-dependency or + group-dependency of the model. In this way, the pruner will force the conv layers + that have dependencies to prune the same channels, so the speedup module can better + harvest the speed benefit from the pruned model. Note that, if this flag is set True + , the dummy_input cannot be None, because the pruner needs a dummy input to trace the + dependency between the conv layers. + dummy_input : torch.Tensor + The dummy input to analyze the topology constraints. Note that, the dummy_input + should on the same device with the model. + """ + + def __init__(self, model, config_list, dependency_aware=False, dummy_input=None): + super().__init__(model, config_list, pruning_algorithm='l1', dependency_aware=dependency_aware, + dummy_input=dummy_input) + + def _supported_dependency_aware(self): + return True + + +class L2FilterPruner(OneshotPruner): + """ + Parameters + ---------- + model : torch.nn.Module + Model to be pruned + config_list : list + Supported keys: + - sparsity : This is to specify the sparsity operations to be compressed to. + - op_types : Only Conv2d is supported in L2FilterPruner. + dependency_aware: bool + If prune the model in a dependency-aware way. If it is `True`, this pruner will + prune the model according to the l2-norm of weights and the channel-dependency or + group-dependency of the model. In this way, the pruner will force the conv layers + that have dependencies to prune the same channels, so the speedup module can better + harvest the speed benefit from the pruned model. Note that, if this flag is set True + , the dummy_input cannot be None, because the pruner needs a dummy input to trace the + dependency between the conv layers. + dummy_input : torch.Tensor + The dummy input to analyze the topology constraints. Note that, the dummy_input + should on the same device with the model. + """ + + def __init__(self, model, config_list, dependency_aware=False, dummy_input=None): + super().__init__(model, config_list, pruning_algorithm='l2', dependency_aware=dependency_aware, + dummy_input=dummy_input) + + def _supported_dependency_aware(self): + return True + + +class FPGMPruner(OneshotPruner): + """ + Parameters + ---------- + model : torch.nn.Module + Model to be pruned + config_list : list + Supported keys: + - sparsity : This is to specify the sparsity operations to be compressed to. + - op_types : Only Conv2d is supported in FPGM Pruner. + dependency_aware: bool + If prune the model in a dependency-aware way. If it is `True`, this pruner will + prune the model according to the l2-norm of weights and the channel-dependency or + group-dependency of the model. In this way, the pruner will force the conv layers + that have dependencies to prune the same channels, so the speedup module can better + harvest the speed benefit from the pruned model. Note that, if this flag is set True + , the dummy_input cannot be None, because the pruner needs a dummy input to trace the + dependency between the conv layers. + dummy_input : torch.Tensor + The dummy input to analyze the topology constraints. Note that, the dummy_input + should on the same device with the model. + """ + + def __init__(self, model, config_list, dependency_aware=False, dummy_input=None): + super().__init__(model, config_list, pruning_algorithm='fpgm', dependency_aware=dependency_aware, + dummy_input=dummy_input) + + def _supported_dependency_aware(self): + return True diff --git a/nni/algorithms/compression/pytorch/pruning/structured_pruning.py b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py similarity index 98% rename from nni/algorithms/compression/pytorch/pruning/structured_pruning.py rename to nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py index 277bed4757..671811e138 100644 --- a/nni/algorithms/compression/pytorch/pruning/structured_pruning.py +++ b/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py @@ -474,8 +474,8 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker): def __init__(self, model, pruner, statistics_batch_num=1): super().__init__(model, pruner) self.pruner.statistics_batch_num = statistics_batch_num - self.pruner.set_wrappers_attribute("contribution", None) self.pruner.iterations = 0 + self.pruner.set_wrappers_attribute("contribution", None) self.pruner.patch_optimizer(self.calc_contributions) def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None): @@ -499,6 +499,7 @@ def calc_contributions(self): """ if self.pruner.iterations >= self.pruner.statistics_batch_num: return + for wrapper in self.pruner.get_modules_wrapper(): filters = wrapper.module.weight.size(0) contribution = ( @@ -677,16 +678,24 @@ class SlimPrunerMasker(WeightMasker): def __init__(self, model, pruner, **kwargs): super().__init__(model, pruner) + self.global_threshold = None + + def _get_global_threshold(self): weight_list = [] - for (layer, _) in pruner.get_modules_to_compress(): + for (layer, _) in self.pruner.get_modules_to_compress(): weight_list.append(layer.module.weight.data.abs().clone()) all_bn_weights = torch.cat(weight_list) - k = int(all_bn_weights.shape[0] * pruner.config_list[0]['sparsity']) + k = int(all_bn_weights.shape[0] * self.pruner.config_list[0]['sparsity']) self.global_threshold = torch.topk( all_bn_weights.view(-1), k, largest=False)[0].max() + print(f'set global threshold to {self.global_threshold}') def calc_mask(self, sparsity, wrapper, wrapper_idx=None): assert wrapper.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning' + + if self.global_threshold is None: + self._get_global_threshold() + weight = wrapper.module.weight.data.clone() if wrapper.weight_mask is not None: # apply base mask for iterative pruning @@ -706,7 +715,6 @@ def calc_mask(self, sparsity, wrapper, wrapper_idx=None): ), 'bias_mask': mask_bias.detach()} return mask - def least_square_sklearn(X, Y): from sklearn.linear_model import LinearRegression reg = LinearRegression(fit_intercept=False) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 62703d449b..dbd5e5b3c3 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -148,6 +148,7 @@ def __init__(self, model, config_list, optimizer=None): super().__init__(model, config_list, optimizer) self.quant_grad = QATGrad.apply modules_to_compress = self.get_modules_to_compress() + device = next(model.parameters()).device self.bound_model.register_buffer("steps", torch.Tensor([1])) for layer, config in modules_to_compress: layer.module.register_buffer("zero_point", torch.Tensor([0.0])) @@ -161,7 +162,7 @@ def __init__(self, model, config_list, optimizer=None): layer.module.register_buffer('activation_bit', torch.zeros(1)) layer.module.register_buffer('tracked_min_activation', torch.zeros(1)) layer.module.register_buffer('tracked_max_activation', torch.zeros(1)) - + self.bound_model.to(device) def _del_simulated_attr(self, module): """ @@ -359,7 +360,7 @@ def step_with_optimizer(self): """ override `compressor` `step` method, quantization only happens after certain number of steps """ - self.bound_model.steps +=1 + self.bound_model.steps += 1 class DoReFaQuantizer(Quantizer): @@ -370,10 +371,12 @@ class DoReFaQuantizer(Quantizer): def __init__(self, model, config_list, optimizer=None): super().__init__(model, config_list, optimizer) + device = next(model.parameters()).device modules_to_compress = self.get_modules_to_compress() for layer, config in modules_to_compress: if "weight" in config.get("quant_types", []): layer.module.register_buffer('weight_bit', torch.zeros(1)) + self.bound_model.to(device) def _del_simulated_attr(self, module): """ @@ -474,11 +477,13 @@ class BNNQuantizer(Quantizer): def __init__(self, model, config_list, optimizer=None): super().__init__(model, config_list, optimizer) + device = next(model.parameters()).device self.quant_grad = ClipGrad.apply modules_to_compress = self.get_modules_to_compress() for layer, config in modules_to_compress: if "weight" in config.get("quant_types", []): layer.module.register_buffer('weight_bit', torch.zeros(1)) + self.bound_model.to(device) def _del_simulated_attr(self, module): """ @@ -589,6 +594,7 @@ def __init__(self, model, config_list, optimizer=None): types of nn.module you want to apply quantization, eg. 'Conv2d' """ super().__init__(model, config_list, optimizer) + device = next(model.parameters()).device self.quant_grad = QuantForward() modules_to_compress = self.get_modules_to_compress() self.bound_model.register_buffer("steps", torch.Tensor([1])) @@ -631,6 +637,8 @@ def __init__(self, model, config_list, optimizer=None): self.optimizer.add_param_group({"params": layer.module.input_scale}) + self.bound_model.to(device) + @staticmethod def grad_scale(x, scale): """ diff --git a/nni/algorithms/compression/tensorflow/pruning/__init__.py b/nni/algorithms/compression/tensorflow/pruning/__init__.py index f8ac8ea9b9..c535fd7512 100644 --- a/nni/algorithms/compression/tensorflow/pruning/__init__.py +++ b/nni/algorithms/compression/tensorflow/pruning/__init__.py @@ -1 +1 @@ -from .one_shot import * +from .one_shot_pruner import * diff --git a/nni/algorithms/compression/tensorflow/pruning/one_shot.py b/nni/algorithms/compression/tensorflow/pruning/one_shot_pruner.py similarity index 100% rename from nni/algorithms/compression/tensorflow/pruning/one_shot.py rename to nni/algorithms/compression/tensorflow/pruning/one_shot_pruner.py diff --git a/nni/compression/pytorch/compressor.py b/nni/compression/pytorch/compressor.py index 08543caf1a..01b8bb24e4 100644 --- a/nni/compression/pytorch/compressor.py +++ b/nni/compression/pytorch/compressor.py @@ -8,7 +8,6 @@ _logger = logging.getLogger(__name__) - class LayerInfo: def __init__(self, name, module): self.module = module @@ -235,7 +234,6 @@ def _wrap_modules(self, layer, config): """ raise NotImplementedError() - def add_activation_collector(self, collector): self._fwd_hook_id += 1 self._fwd_hook_handles[self._fwd_hook_id] = [] @@ -264,6 +262,18 @@ def new_step(_, *args, **kwargs): if self.optimizer is not None: self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer) + def patch_optimizer_before(self, *tasks): + def patch_step(old_step): + def new_step(_, *args, **kwargs): + for task in tasks: + task() + # call origin optimizer step method + output = old_step(*args, **kwargs) + return output + return new_step + if self.optimizer is not None: + self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer) + class PrunerModuleWrapper(torch.nn.Module): def __init__(self, module, module_name, module_type, config, pruner): """ @@ -319,8 +329,6 @@ class Pruner(Compressor): def __init__(self, model, config_list, optimizer=None): super().__init__(model, config_list, optimizer) - if optimizer is not None: - self.patch_optimizer(self.update_mask) def compress(self): self.update_mask() @@ -386,7 +394,7 @@ def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=N """ assert model_path is not None, 'model_path must be specified' mask_dict = {} - self._unwrap_model() # used for generating correct state_dict name without wrapper state + self._unwrap_model() # used for generating correct state_dict name without wrapper state for wrapper in self.get_modules_wrapper(): weight_mask = wrapper.weight_mask @@ -433,6 +441,27 @@ def load_model_state_dict(self, model_state): else: self.bound_model.load_state_dict(model_state) + def get_pruned_weights(self, dim=0): + """ + Log the simulated prune sparsity. + + Parameters + ---------- + dim : int + the pruned dim. + """ + for _, wrapper in enumerate(self.get_modules_wrapper()): + weight_mask = wrapper.weight_mask + mask_size = weight_mask.size() + if len(mask_size) == 1: + index = torch.nonzero(weight_mask.abs() != 0).tolist() + else: + sum_idx = list(range(len(mask_size))) + sum_idx.remove(dim) + index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0).tolist() + _logger.info(f'simulated prune {wrapper.name} remain/total: {len(index)}/{weight_mask.size(dim)}') + + class QuantizerModuleWrapper(torch.nn.Module): def __init__(self, module, module_name, module_type, config, quantizer): """ @@ -549,7 +578,6 @@ def quantize_input(self, *inputs, wrapper, **kwargs): """ raise NotImplementedError('Quantizer must overload quantize_input()') - def _wrap_modules(self, layer, config): """ Create a wrapper forward function to replace the original one. @@ -571,8 +599,8 @@ def _wrap_modules(self, layer, config): return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self) - def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None, \ - input_shape=None, device=None): + def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None, + input_shape=None, device=None): """ This method helps save pytorch model, calibration config, onnx model in quantizer. @@ -671,6 +699,7 @@ def _quantize(cls, x, scale, zero_point): quantized x without clamped """ return ((x / scale) + zero_point).round() + @classmethod def get_bits_length(cls, config, quant_type): """ @@ -703,8 +732,8 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qma grad_output : Tensor gradient of the output of quantization operation scale : Tensor - the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`, - you can define different behavior for different types. + the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, + `QuantType.QUANT_OUTPUT`, you can define different behavior for different types. zero_point : Tensor zero_point for quantizing tensor qmin : Tensor diff --git a/nni/compression/pytorch/utils/mask_conflict.py b/nni/compression/pytorch/utils/mask_conflict.py index 8e37893ba4..e89372d60e 100644 --- a/nni/compression/pytorch/utils/mask_conflict.py +++ b/nni/compression/pytorch/utils/mask_conflict.py @@ -31,7 +31,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None): # if the input is the path of the mask_file assert os.path.exists(masks) masks = torch.load(masks) - assert len(masks) > 0, 'Mask tensor cannot be empty' + assert len(masks) > 0, 'Mask tensor cannot be empty' # if the user uses the model and dummy_input to trace the model, we # should get the traced model handly, so that, we only trace the # model once, GroupMaskConflict and ChannelMaskConflict will reuse @@ -181,10 +181,8 @@ def fix_mask(self): w_mask = self.masks[layername]['weight'] shape = w_mask.size() count = np.prod(shape[1:]) - all_ones = (w_mask.flatten(1).sum(-1) == - count).nonzero().squeeze(1).tolist() - all_zeros = (w_mask.flatten(1).sum(-1) == - 0).nonzero().squeeze(1).tolist() + all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist() + all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist() if len(all_ones) + len(all_zeros) < w_mask.size(0): # In fine-grained pruning, skip this layer _logger.info('Layers %s using fine-grained pruning', layername) @@ -198,7 +196,7 @@ def fix_mask(self): group_masked = [] for i in range(group): _start = step * i - _end = step * (i+1) + _end = step * (i + 1) _tmp_list = list( filter(lambda x: _start <= x and x < _end, all_zeros)) group_masked.append(_tmp_list) @@ -286,7 +284,7 @@ def fix_mask(self): 0, 2, 3) if self.conv_prune_dim == 0 else (1, 2, 3) channel_mask = (mask.abs().sum(tmp_sum_idx) != 0).int() channel_masks.append(channel_mask) - if (channel_mask.sum() * (mask.numel() / mask.shape[1-self.conv_prune_dim])).item() != (mask > 0).sum().item(): + if (channel_mask.sum() * (mask.numel() / mask.shape[1 - self.conv_prune_dim])).item() != (mask > 0).sum().item(): fine_grained = True else: raise RuntimeError( diff --git a/test/ut/sdk/test_compressor_torch.py b/test/ut/sdk/test_compressor_torch.py index b7b0b2019e..350e46025e 100644 --- a/test/ut/sdk/test_compressor_torch.py +++ b/test/ut/sdk/test_compressor_torch.py @@ -61,9 +61,8 @@ def test_torch_quantizer_modules_detection(self): def test_torch_level_pruner(self): model = TorchModel() - optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) configure_list = [{'sparsity': 0.8, 'op_types': ['default']}] - torch_pruner.LevelPruner(model, configure_list, optimizer).compress() + torch_pruner.LevelPruner(model, configure_list).compress() def test_torch_naive_quantizer(self): model = TorchModel() @@ -93,7 +92,7 @@ def test_torch_fpgm_pruner(self): model = TorchModel() config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}] - pruner = torch_pruner.FPGMPruner(model, config_list, torch.optim.SGD(model.parameters(), lr=0.01)) + pruner = torch_pruner.FPGMPruner(model, config_list) model.conv2.module.weight.data = torch.tensor(w).float() masks = pruner.calc_mask(model.conv2) @@ -152,7 +151,7 @@ def test_torch_slim_pruner(self): config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}] model.bn1.weight.data = torch.tensor(w).float() model.bn2.weight.data = torch.tensor(-w).float() - pruner = torch_pruner.SlimPruner(model, config_list) + pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None) mask1 = pruner.calc_mask(model.bn1) mask2 = pruner.calc_mask(model.bn2) @@ -165,7 +164,7 @@ def test_torch_slim_pruner(self): config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}] model.bn1.weight.data = torch.tensor(w).float() model.bn2.weight.data = torch.tensor(w).float() - pruner = torch_pruner.SlimPruner(model, config_list) + pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None) mask1 = pruner.calc_mask(model.bn1) mask2 = pruner.calc_mask(model.bn2) @@ -202,8 +201,8 @@ def test_torch_taylorFOweight_pruner(self): model = TorchModel() optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) - pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, statistics_batch_num=1) - + pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsity_training_epochs=1) + x = torch.rand((1, 1, 28, 28), requires_grad=True) model.conv1.module.weight.data = torch.tensor(w1).float() model.conv2.module.weight.data = torch.tensor(w2).float() @@ -345,7 +344,7 @@ def test_torch_pruner_validation(self): ], [ {'sparsity': 0.2 }, - {'sparsity': 0.6, 'op_names': 'abc' } + {'sparsity': 0.6, 'op_names': 'abc'} ] ] model = TorchModel() @@ -353,7 +352,13 @@ def test_torch_pruner_validation(self): for pruner_class in pruner_classes: for config_list in bad_configs: try: - pruner_class(model, config_list, optimizer) + kwargs = {} + if pruner_class in (torch_pruner.SlimPruner, torch_pruner.AGPPruner, torch_pruner.ActivationMeanRankFilterPruner, torch_pruner.ActivationAPoZRankFilterPruner): + kwargs = {'optimizer': None, 'trainer': None, 'criterion': None} + + print('kwargs', kwargs) + pruner_class(model, config_list, **kwargs) + print(config_list) assert False, 'Validation error should be raised for bad configuration' except schema.SchemaError: diff --git a/test/ut/sdk/test_dependecy_aware.py b/test/ut/sdk/test_dependecy_aware.py index 5918f502bf..5823d1a408 100644 --- a/test/ut/sdk/test_dependecy_aware.py +++ b/test/ut/sdk/test_dependecy_aware.py @@ -46,6 +46,24 @@ def generate_random_sparsity_v2(model): 'sparsity': sparsity}) return cfg_list +def train(model, criterion, optimizer, callback=None): + model.train() + device = next(model.parameters()).device + data = torch.randn(2, 3, 224, 224).to(device) + target = torch.tensor([0, 1]).long().to(device) + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + loss.backward() + + # callback should be inserted between loss.backward() and optimizer.step() + if callback: + callback() + + optimizer.step() + +def trainer(model, optimizer, criterion, epoch, callback=None): + return train(model, criterion, optimizer, callback=callback) class DependencyawareTest(TestCase): @unittest.skipIf(torch.__version__ < "1.3.0", "not supported") @@ -55,6 +73,7 @@ def test_dependency_aware_pruning(self): sparsity = 0.7 cfg_list = [{'op_types': ['Conv2d'], 'sparsity':sparsity}] dummy_input = torch.ones(1, 3, 224, 224) + for model_name in model_zoo: for pruner in pruners: print('Testing on ', pruner) @@ -72,16 +91,12 @@ def test_dependency_aware_pruning(self): momentum=0.9, weight_decay=4e-5) criterion = torch.nn.CrossEntropyLoss() - tmp_pruner = pruner( - net, cfg_list, optimizer, dependency_aware=True, dummy_input=dummy_input) - # train one single batch so that the the pruner can collect the - # statistic - optimizer.zero_grad() - out = net(dummy_input) - batchsize = dummy_input.size(0) - loss = criterion(out, torch.zeros(batchsize, dtype=torch.int64)) - loss.backward() - optimizer.step() + if pruner == TaylorFOWeightFilterPruner: + tmp_pruner = pruner( + net, cfg_list, optimizer, trainer=trainer, criterion=criterion, dependency_aware=True, dummy_input=dummy_input) + else: + tmp_pruner = pruner( + net, cfg_list, dependency_aware=True, dummy_input=dummy_input) tmp_pruner.compress() tmp_pruner.export_model(MODEL_FILE, MASK_FILE) @@ -91,7 +106,7 @@ def test_dependency_aware_pruning(self): ms.speedup_model() for name, module in net.named_modules(): if isinstance(module, nn.Conv2d): - expected = int(ori_filters[name] * (1-sparsity)) + expected = int(ori_filters[name] * (1 - sparsity)) filter_diff = abs(expected - module.out_channels) errmsg = '%s Ori: %d, Expected: %d, Real: %d' % ( name, ori_filters[name], expected, module.out_channels) @@ -124,16 +139,13 @@ def test_dependency_aware_random_config(self): momentum=0.9, weight_decay=4e-5) criterion = torch.nn.CrossEntropyLoss() - tmp_pruner = pruner( - net, cfg_list, optimizer, dependency_aware=True, dummy_input=dummy_input) - # train one single batch so that the the pruner can collect the - # statistic - optimizer.zero_grad() - out = net(dummy_input) - batchsize = dummy_input.size(0) - loss = criterion(out, torch.zeros(batchsize, dtype=torch.int64)) - loss.backward() - optimizer.step() + + if pruner in (TaylorFOWeightFilterPruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner): + tmp_pruner = pruner( + net, cfg_list, optimizer, trainer=trainer, criterion=criterion, dependency_aware=True, dummy_input=dummy_input) + else: + tmp_pruner = pruner( + net, cfg_list, dependency_aware=True, dummy_input=dummy_input) tmp_pruner.compress() tmp_pruner.export_model(MODEL_FILE, MASK_FILE) diff --git a/test/ut/sdk/test_model_speedup.py b/test/ut/sdk/test_model_speedup.py index ecbdb89e6d..9ce7a7cba9 100644 --- a/test/ut/sdk/test_model_speedup.py +++ b/test/ut/sdk/test_model_speedup.py @@ -17,7 +17,7 @@ from nni.compression.pytorch import ModelSpeedup, apply_compression_results from nni.algorithms.compression.pytorch.pruning import L1FilterPruner from nni.algorithms.compression.pytorch.pruning.weight_masker import WeightMasker -from nni.algorithms.compression.pytorch.pruning.one_shot import _StructuredFilterPruner +from nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner import DependencyAwarePruner torch.manual_seed(0) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -205,7 +205,7 @@ def calc_mask(self, sparsity, wrapper, wrapper_idx=None): return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias} -class L1ChannelPruner(_StructuredFilterPruner): +class L1ChannelPruner(DependencyAwarePruner): def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None): super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer, dependency_aware=dependency_aware, dummy_input=dummy_input) diff --git a/test/ut/sdk/test_pruners.py b/test/ut/sdk/test_pruners.py index d9b10d63a9..e6948f2677 100644 --- a/test/ut/sdk/test_pruners.py +++ b/test/ut/sdk/test_pruners.py @@ -42,13 +42,10 @@ def validate_sparsity(wrapper, sparsity, bias=False): 'agp': { 'pruner_class': AGPPruner, 'config_list': [{ - 'initial_sparsity': 0., - 'final_sparsity': 0.8, - 'start_epoch': 0, - 'end_epoch': 10, - 'frequency': 1, + 'sparsity': 0.8, 'op_types': ['Conv2d'] }], + 'trainer': lambda model, optimizer, criterion, epoch: model, 'validators': [] }, 'slim': { @@ -57,6 +54,7 @@ def validate_sparsity(wrapper, sparsity, bias=False): 'sparsity': 0.7, 'op_types': ['BatchNorm2d'] }], + 'trainer': lambda model, optimizer, criterion, epoch: model, 'validators': [ lambda model: validate_sparsity(model.bn1, 0.7, model.bias) ] @@ -97,6 +95,7 @@ def validate_sparsity(wrapper, sparsity, bias=False): 'sparsity': 0.5, 'op_types': ['Conv2d'], }], + 'trainer': lambda model, optimizer, criterion, epoch: model, 'validators': [ lambda model: validate_sparsity(model.conv1, 0.5, model.bias) ] @@ -107,6 +106,7 @@ def validate_sparsity(wrapper, sparsity, bias=False): 'sparsity': 0.5, 'op_types': ['Conv2d'], }], + 'trainer': lambda model, optimizer, criterion, epoch: model, 'validators': [ lambda model: validate_sparsity(model.conv1, 0.5, model.bias) ] @@ -117,6 +117,7 @@ def validate_sparsity(wrapper, sparsity, bias=False): 'sparsity': 0.5, 'op_types': ['Conv2d'], }], + 'trainer': lambda model, optimizer, criterion, epoch: model, 'validators': [ lambda model: validate_sparsity(model.conv1, 0.5, model.bias) ] @@ -127,7 +128,7 @@ def validate_sparsity(wrapper, sparsity, bias=False): 'sparsity': 0.5, 'op_types': ['Conv2d'] }], - 'short_term_fine_tuner': lambda model:model, + 'short_term_fine_tuner': lambda model: model, 'evaluator':lambda model: 0.9, 'validators': [] }, @@ -146,7 +147,7 @@ def validate_sparsity(wrapper, sparsity, bias=False): 'sparsity': 0.5, 'op_types': ['Conv2d'], }], - 'trainer': lambda model, optimizer, criterion, epoch, callback : model, + 'trainer': lambda model, optimizer, criterion, epoch : model, 'validators': [ lambda model: validate_sparsity(model.conv1, 0.5, model.bias) ] @@ -158,7 +159,7 @@ def validate_sparsity(wrapper, sparsity, bias=False): 'op_types': ['Conv2d'], }], 'base_algo': 'l1', - 'trainer': lambda model, optimizer, criterion, epoch, callback : model, + 'trainer': lambda model, optimizer, criterion, epoch : model, 'evaluator': lambda model: 0.9, 'dummy_input': torch.randn([64, 1, 28, 28]), 'validators': [] @@ -170,7 +171,7 @@ def validate_sparsity(wrapper, sparsity, bias=False): 'op_types': ['Conv2d'], }], 'base_algo': 'l2', - 'trainer': lambda model, optimizer, criterion, epoch, callback : model, + 'trainer': lambda model, optimizer, criterion, epoch : model, 'evaluator': lambda model: 0.9, 'dummy_input': torch.randn([64, 1, 28, 28]), 'validators': [] @@ -182,7 +183,7 @@ def validate_sparsity(wrapper, sparsity, bias=False): 'op_types': ['Conv2d'], }], 'base_algo': 'fpgm', - 'trainer': lambda model, optimizer, criterion, epoch, callback : model, + 'trainer': lambda model, optimizer, criterion, epoch : model, 'evaluator': lambda model: 0.9, 'dummy_input': torch.randn([64, 1, 28, 28]), 'validators': [] @@ -206,88 +207,87 @@ def __init__(self, bias=True): def forward(self, x): return self.fc(self.pool(self.bn1(self.conv1(x))).view(x.size(0), -1)) +class SimpleDataset: + def __getitem__(self, index): + return torch.randn(3, 32, 32), 1. + + def __len__(self): + return 1000 + +def train(model, train_loader, criterion, optimizer): + model.train() + device = next(model.parameters()).device + x = torch.randn(2, 1, 28, 28).to(device) + y = torch.tensor([0, 1]).long().to(device) + # print('hello...') + + for _ in range(2): + out = model(x) + loss = criterion(out, y) + optimizer.zero_grad() + loss.backward() + + optimizer.step() + def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'taylorfo', 'mean_activation', 'apoz', 'netadapt', 'simulatedannealing', 'admm', 'autocompress_l1', 'autocompress_l2', 'autocompress_fpgm',], bias=True): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dummy_input = torch.randn(2, 1, 28, 28).to(device) + + criterion = torch.nn.CrossEntropyLoss() + train_loader = torch.utils.data.DataLoader(SimpleDataset(), batch_size=16, shuffle=False, drop_last=True) + + def trainer(model, optimizer, criterion, epoch): + return train(model, train_loader, criterion, optimizer) + for pruner_name in pruner_names: print('testing {}...'.format(pruner_name)) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = Model(bias=bias).to(device) optimizer = torch.optim.SGD(model.parameters(), lr=0.01) config_list = prune_config[pruner_name]['config_list'] - x = torch.randn(2, 1, 28, 28).to(device) - y = torch.tensor([0, 1]).long().to(device) - out = model(x) - loss = F.cross_entropy(out, y) - optimizer.zero_grad() - loss.backward() - optimizer.step() - if pruner_name == 'netadapt': pruner = prune_config[pruner_name]['pruner_class'](model, config_list, short_term_fine_tuner=prune_config[pruner_name]['short_term_fine_tuner'], evaluator=prune_config[pruner_name]['evaluator']) elif pruner_name == 'simulatedannealing': pruner = prune_config[pruner_name]['pruner_class'](model, config_list, evaluator=prune_config[pruner_name]['evaluator']) + elif pruner_name in ('agp', 'slim', 'taylorfo', 'apoz', 'mean_activation'): + pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=trainer, optimizer=optimizer, criterion=criterion) elif pruner_name == 'admm': - pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer']) + pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=trainer) elif pruner_name.startswith('autocompress'): - pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'], evaluator=prune_config[pruner_name]['evaluator'], dummy_input=x, base_algo=prune_config[pruner_name]['base_algo']) + pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'], evaluator=prune_config[pruner_name]['evaluator'], criterion=torch.nn.CrossEntropyLoss(), dummy_input=dummy_input, base_algo=prune_config[pruner_name]['base_algo']) else: - pruner = prune_config[pruner_name]['pruner_class'](model, config_list, optimizer) - pruner.compress() - - x = torch.randn(2, 1, 28, 28).to(device) - y = torch.tensor([0, 1]).long().to(device) - out = model(x) - loss = F.cross_entropy(out, y) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - if pruner_name == 'taylorfo': - # taylorfo algorithm calculate contributions at first iteration(step), and do pruning - # when iteration >= statistics_batch_num (default 1) - optimizer.step() + pruner = prune_config[pruner_name]['pruner_class'](model, config_list) + pruner.compress() pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', input_shape=(2,1,28,28), device=device) for v in prune_config[pruner_name]['validators']: v(model) - filePaths = ['./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', './search_history.csv', './search_result.json'] for f in filePaths: if os.path.exists(f): os.remove(f) -def _test_agp(pruning_algorithm): - model = Model() - optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - config_list = prune_config['agp']['config_list'] - pruner = AGPPruner(model, config_list, optimizer, pruning_algorithm=pruning_algorithm) - pruner.compress() - - x = torch.randn(2, 1, 28, 28) - y = torch.tensor([0, 1]).long() +def _test_agp(pruning_algorithm): + train_loader = torch.utils.data.DataLoader(SimpleDataset(), batch_size=16, shuffle=False, drop_last=True) + model = Model() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - for epoch in range(config_list[0]['start_epoch'], config_list[0]['end_epoch']+1): - pruner.update_epoch(epoch) - out = model(x) - loss = F.cross_entropy(out, y) - optimizer.zero_grad() - loss.backward() - optimizer.step() + def trainer(model, optimizer, criterion, epoch): + return train(model, train_loader, criterion, optimizer) - target_sparsity = pruner.compute_target_sparsity(config_list[0]) - actual_sparsity = (model.conv1.weight_mask == 0).sum().item() / model.conv1.weight_mask.numel() - # set abs_tol = 0.2, considering the sparsity error for channel pruning when number of channels is small. - assert math.isclose(actual_sparsity, target_sparsity, abs_tol=0.2) + config_list = prune_config['agp']['config_list'] + pruner = AGPPruner(model, config_list, optimizer=optimizer, trainer=trainer, criterion=torch.nn.CrossEntropyLoss(), pruning_algorithm=pruning_algorithm) + pruner.compress() -class SimpleDataset: - def __getitem__(self, index): - return torch.randn(3, 32, 32), 1. + target_sparsity = pruner.compute_target_sparsity(config_list[0]) + actual_sparsity = (model.conv1.weight_mask == 0).sum().item() / model.conv1.weight_mask.numel() + # set abs_tol = 0.2, considering the sparsity error for channel pruning when number of channels is small. + assert math.isclose(actual_sparsity, target_sparsity, abs_tol=0.2) - def __len__(self): - return 1000 class PrunerTestCase(TestCase): def test_pruners(self):