This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Constraint-aware one-shot pruners #2657

Merged: 111 commits, Sep 21, 2020
Changes from 20 commits

Commits (111)
3515b29
constraint-aware pruner
Jul 5, 2020
1c84925
Constrained Structure pruner.
Jul 6, 2020
3ed53cb
Constrained pruner.
Jul 6, 2020
a415206
Constrained one-shot pruner.
Jul 6, 2020
5fa19fc
Constraint aware pruner.
Jul 7, 2020
bed63fe
Constrained one-shot pruner.
Jul 7, 2020
12c24e5
Constrained one shot pruner.
Jul 7, 2020
45e62c2
Constrained-aware one-shot pruner.
Jul 7, 2020
aeb8aaf
Update the doc.
Jul 7, 2020
70dac7c
reformat the unit test.
Jul 7, 2020
5362328
Add test case for constrained-aware pruners.
Jul 8, 2020
bd375a4
Remove the unnecessary log function.
Jul 8, 2020
9d0fb79
fix pylint errors.
Jul 8, 2020
211a047
Add the docs for the constrained pruners.
Jul 13, 2020
a11cf48
empty commit
Jul 13, 2020
899b6f9
Merge branch 'master' of https://github.com/microsoft/nni into constr…
Jul 13, 2020
439426f
Add an accuracy comparsion benchmark for Constrained Pruner.
Jul 20, 2020
a08dac6
update
Jul 20, 2020
39eadcc
Merge branch 'master' of https://github.com/microsoft/nni into constr…
Jul 23, 2020
fab3315
update the benchmark
Jul 23, 2020
4bc3c48
Update constrained pruner benchmark.
Jul 24, 2020
6c32a7c
update
Jul 25, 2020
9923ea1
update
Jul 27, 2020
5a35c8e
update
Jul 28, 2020
00eb006
fix a bug.
Jul 28, 2020
65676ad
update.
Jul 28, 2020
50d3468
update
Jul 28, 2020
d358cb2
update
Aug 2, 2020
174233c
tmp branch
Aug 3, 2020
b36589c
update
Aug 3, 2020
482e500
update
Aug 4, 2020
e5b262e
support imagenet for auto_pruners_torch
Aug 4, 2020
2591abc
update
Aug 4, 2020
0ee36b8
and a switch for the constrained pruner
Aug 4, 2020
19df319
update
Aug 4, 2020
b2f03b7
update
Aug 4, 2020
e263349
update
Aug 4, 2020
fca51bf
update
Aug 5, 2020
0fabf61
update
Aug 5, 2020
1ac7050
update
Aug 6, 2020
e55fe10
bug in the sm pruner
Aug 6, 2020
c1f9a45
update
Aug 6, 2020
9aa1558
add one more mile stone
Aug 6, 2020
e9b39fb
fix a bug caused by the expand and clone
Aug 7, 2020
d7cc452
add a constrained switch for the auto compress pruner
Aug 7, 2020
9183b93
add support for imagenet
Aug 11, 2020
e3226ee
unfinish
Aug 21, 2020
eca8577
attention pruner unfinished
Aug 24, 2020
ad58382
update
Aug 25, 2020
90c1c47
merge from master
Aug 25, 2020
12c289e
update
Aug 26, 2020
4029b1c
update
Aug 26, 2020
f3098fb
update
Aug 27, 2020
1fd32f5
update
Aug 27, 2020
755ce8b
update
Aug 27, 2020
1639867
updata
Aug 27, 2020
fe78c59
update
Aug 27, 2020
5c54faf
update
Aug 27, 2020
f5d4060
update
Aug 28, 2020
e5bcd6a
add no dependency
Aug 29, 2020
f625f81
use softmax in the attention pruner
Aug 31, 2020
e5f3e01
update
Aug 31, 2020
fcc984c
update
Aug 31, 2020
13d4f38
update
Sep 1, 2020
b7bac26
update
Sep 2, 2020
b3d1ac9
update
Sep 3, 2020
90d1e45
add the unit test.
Sep 3, 2020
cf7c936
update
Sep 3, 2020
85fc79f
update
Sep 3, 2020
b593ba3
update
Sep 3, 2020
6682cb3
update
Sep 3, 2020
25beb8f
update doc string
Sep 4, 2020
649ecfd
update the documentation
Sep 4, 2020
fb09b3f
Remove the attention pruner.
Sep 4, 2020
da5525a
remove the mobilenet_v2 for cifar10
Sep 4, 2020
0920efe
reset the auto_pruners_torch.py
Sep 4, 2020
74f4ec4
update the example to the new interface.
Sep 4, 2020
45966c9
fix pylint errors
Sep 4, 2020
e02fb90
update the example
Sep 4, 2020
cf626f8
fix a bug when counting flops
Sep 7, 2020
f444ebc
add several new one-shot pruners
Sep 8, 2020
a0d1e97
support more one_shot prunersw
Sep 8, 2020
59f0fe1
test
Sep 8, 2020
68e4563
fix a bug in the original apoz pruner
Sep 8, 2020
bddb70f
update
Sep 9, 2020
646324a
update
Sep 9, 2020
c7ba084
update
Sep 9, 2020
7a54cd6
update
Sep 9, 2020
c42de2c
update
Sep 9, 2020
9b9bb09
update
Sep 10, 2020
82f4fdb
update the unit test
Sep 10, 2020
8438ff5
update the examples
Sep 10, 2020
f9028f5
rm the test_dependency_aware
Sep 10, 2020
8afa53a
update
Sep 10, 2020
5c6d60e
update
Sep 10, 2020
78f3fc6
update the doc
Sep 10, 2020
2125e98
update rst
Sep 10, 2020
ae62671
Merge branch 'master' of https://github.com/microsoft/nni into constr…
Sep 11, 2020
8963a01
update
Sep 11, 2020
3691f23
Merge branch 'master' of https://github.com/microsoft/nni into constr…
Sep 13, 2020
29029ee
update doc
Sep 14, 2020
a8f3f74
update the doc
Sep 14, 2020
9bf7667
update doc
Sep 14, 2020
b7b7150
update the doc
Sep 14, 2020
e68cec0
update
Sep 16, 2020
c9c5329
update
Sep 16, 2020
ef51a10
update
Sep 16, 2020
4acabaa
update
Sep 16, 2020
80aec67
add some evaluation results
Sep 16, 2020
3a73e30
update
Sep 21, 2020
d5bbe48
update the doc
Sep 21, 2020
51 changes: 51 additions & 0 deletions docs/en_US/Compressor/Pruner.md
@@ -10,9 +10,12 @@ We provide several pruning algorithms that support fine-grained weight pruning a
* [Slim Pruner](#slim-pruner)
* [FPGM Pruner](#fpgm-pruner)
* [L1Filter Pruner](#l1filter-pruner)
* [Constrained L1Filter Pruner](#constrained-l1filter-pruner)
* [L2Filter Pruner](#l2filter-pruner)
* [Constrained L2Filter Pruner](#constrained-l2filter-pruner)
* [APoZ Rank Pruner](#activationapozrankfilterpruner)
* [Activation Mean Rank Pruner](#activationmeanrankfilterpruner)
* [Constrained Activation Mean Rank Filter Pruner](#constrained-activationmeanrankfilter-pruner)
* [Taylor FO On Weight Pruner](#taylorfoweightfilterpruner)

**Pruning Schedule**
@@ -177,6 +180,27 @@ The experiments code can be found at [examples/model_compress]( https://github.c

***

## Constrained L1Filter Pruner
This is a topology constraint-aware one-shot pruner. Compared to the [original L1 Filter Pruner](#l1filter-pruner), this pruner prunes the model based not only on the L1 norm of each filter, but also on the topology of the target model's network architecture. For example, if the output channels of two convolutional layers (conv1, conv2) are added together, then these two conv layers have a channel dependency on each other (see [Compression Utils](./CompressionUtils.md) for details). Suppose we prune the first 50% of the output channels (filters) of conv1 and the last 50% of the output channels of conv2. Although both layers have 50% of their filters pruned, the speedup module still needs to pad zeros to align the output channels, so we cannot harvest the speed benefit from the model pruning. To better gain the speed benefit of model pruning, we developed this constraint (topology)-aware one-shot pruner.
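
For illustration, a minimal model that contains exactly this kind of dependency could look like the sketch below (a hypothetical example, not NNI code): the outputs of `conv1` and `conv2` are added element-wise, so the two layers form one channel-dependency set.

```python
import torch.nn as nn

class AddDependency(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        # the element-wise add forces conv1 and conv2 to keep the same output channels after pruning
        return self.conv1(x) + self.conv2(x)
```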

The `Constrained L1Filter Pruner` tries to prune the same output channels for layers that have channel dependencies with each other. For each channel, the pruner first computes the sum of the L1 norms of that channel across all the layers in the dependency set. The maximum sparsity that can be shared by the channels of a dependency set is determined by the minimum configured sparsity among its layers (denoted by `min_sparsity`). According to the summed L1 norms, `Constrained L1Filter Pruner` prunes the same `min_sparsity` fraction of channels for all the layers in the set. Next, the pruner additionally prunes `sparsity` - `min_sparsity` of the channels for each convolutional layer based on that layer's own per-channel L1 norms. For example, suppose the output channels of `conv1` and `conv2` are added together and the configured sparsities of `conv1` and `conv2` are 0.3 and 0.2 respectively. In this case, `Constrained L1Filter Pruner` prunes the same 20% of channels for `conv1` and `conv2` according to the summed L1 norms of `conv1` and `conv2`, and then prunes an additional 10% of the channels of `conv1` according to the L1 norm of each channel of `conv1`.
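
The selection logic can be sketched as follows (a simplified, hypothetical helper, not the actual NNI masker implementation; `weights` maps each layer name in one dependency set to its Conv2d weight tensor, and `sparsities` maps the same names to their configured sparsities):

```python
import torch

def constrained_l1_channel_selection(weights, sparsities):
    num_channels = next(iter(weights.values())).shape[0]
    # per-layer, per-channel L1 norms of the filters
    per_layer_norm = {name: w.abs().sum(dim=(1, 2, 3)) for name, w in weights.items()}
    # stage 1: prune the same `min_sparsity` fraction of channels for the whole
    # dependency set, ranked by the sum of the L1 norms across all layers
    min_sparsity = min(sparsities.values())
    num_common = int(num_channels * min_sparsity)
    summed = torch.stack(list(per_layer_norm.values())).sum(dim=0)
    common_pruned = set(torch.argsort(summed)[:num_common].tolist())
    # stage 2: each layer additionally prunes (sparsity - min_sparsity) of its
    # channels, ranked by its own per-channel L1 norms
    pruned = {}
    for name, norm in per_layer_norm.items():
        extra = int(num_channels * (sparsities[name] - min_sparsity))
        own_order = [c for c in torch.argsort(norm).tolist() if c not in common_pruned]
        pruned[name] = common_pruned | set(own_order[:extra])
    return pruned
```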

In addition, for convolutional layers that have more than one filter group, `Constrained L1Filter Pruner` also tries to prune the same number of channels in each filter group. Overall, this pruner prunes the model according to the L1 norm of each filter while trying to satisfy the topological constraints (channel dependency, etc.) so as to improve the final speed gain after the speedup process.

In short, compared to `L1FilterPruner`, `Constrained_L1FilterPruner` provides a better speed gain from model pruning.


### Usage
PyTorch code
```python
from nni.compression.torch import Constrained_L1FilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
dummy_input = torch.rand(1, 3, 32, 32)
pruner = Constrained_L1FilterPruner(model, config_list, dummy_input)
pruner.compress()
```
Compared to `L1FilterPruner`, `Constrained_L1FilterPruner` needs an additional input parameter called `dummy_input` to analyze the topology of the input model. The other input parameters are the same as for `L1FilterPruner`.

## L2Filter Pruner

This is a structured pruning algorithm that prunes the filters with the smallest L2 norm of the weights. It is implemented as a one-shot pruner.
@@ -199,6 +223,19 @@ pruner.compress()

***

## Constrained L2Filter Pruner
Similar to the Constrained L1Filter Pruner, this pruner prunes the model based on the L2 norm of each filter and the topology of the model.

### Usage
PyTorch code
```python
from nni.compression.torch import Constrained_L2FilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
dummy_input = torch.rand(1, 3, 32, 32)
pruner = Constrained_L2FilterPruner(model, config_list, dummy_input)
pruner.compress()
```

## ActivationAPoZRankFilterPruner

ActivationAPoZRankFilterPruner is a pruner which prunes the filters with the smallest importance criterion `APoZ` calculated from the output activations of convolution layers to achieve a preset level of network sparsity. The pruning criterion `APoZ` is explained in the paper [Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures](https://arxiv.org/abs/1607.03250).
@@ -261,6 +298,20 @@ You can view example for more information

***


## Constrained ActivationMeanRankFilter Pruner
Similar to the Constrained L1Filter Pruner, this pruner prunes the model based on the mean activation rank of the filters and the topology of the model.

### Usage
PyTorch code
```python
from nni.compression.torch import ConstrainedActivationMeanRankFilterPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
dummy_input = torch.rand(1, 3, 32, 32)
pruner = ConstrainedActivationMeanRankFilterPruner(model, config_list, dummy_input)
pruner.compress()
```
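
As with the original activation-based pruners, the masker ranks filters using statistics collected from the layers' output activations, so a few forward passes are needed before calling `compress()`. The snippet below is a minimal sketch following `examples/model_compress/constrained_pruner.py` from this PR, assuming `model`, `config_list`, `dummy_input`, `optimizer`, `train_loader` and `device` are already defined:

```python
pruner = ConstrainedActivationMeanRankFilterPruner(
    model, config_list, dummy_input, optimizer, statistics_batch_num=10)
# feed a few batches so the masker can collect activation statistics
for batch_idx, (data, _) in enumerate(train_loader):
    model(data.to(device))
    if batch_idx > 10:
        break
pruner.compress()
```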

## TaylorFOWeightFilterPruner

TaylorFOWeightFilterPruner is a pruner which prunes convolutional layers based on estimated importance calculated from the first-order Taylor expansion on weights to achieve a preset level of network sparsity. The estimated importance of filters is defined in the paper [Importance Estimation for Neural Network Pruning](http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf). Other pruning criteria mentioned in this paper will be supported in a future release.
219 changes: 219 additions & 0 deletions examples/model_compress/constrained_pruner.py
@@ -0,0 +1,219 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
'''
Example comparing the constraint-aware one-shot pruners with their original counterparts.
'''

import argparse
import os
import json
import torch
from torch.optim.lr_scheduler import StepLR, MultiStepLR
from torchvision import datasets, transforms, models

from models.mnist.lenet import LeNet
from models.cifar10.vgg import VGG
from nni.compression.torch import L1FilterPruner, Constrained_L1FilterPruner
from nni.compression.torch import L2FilterPruner, Constrained_L2FilterPruner
from nni.compression.torch import ActivationMeanRankFilterPruner, ConstrainedActivationMeanRankFilterPruner
from nni.compression.torch import ModelSpeedup
from nni.compression.torch.utils.counter import count_flops_params

def cifar10_dataset(args):
"""
return the train & test dataloader for the cifar10 dataset.
"""
kwargs = {'num_workers': 10, 'pin_memory': True} if torch.cuda.is_available() else {
}


normalize = transforms.Normalize(
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10(args.data_dir, train=True, transform=transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, 4),
transforms.ToTensor(),
normalize,
]), download=True),
batch_size=args.batch_size, shuffle=True, **kwargs)

val_loader = torch.utils.data.DataLoader(
datasets.CIFAR10(args.data_dir, train=False, transform=transforms.Compose([
transforms.ToTensor(),
normalize,
])),
batch_size=args.batch_size, shuffle=False, **kwargs)
dummy_input = torch.ones(1, 3, 32, 32)
return train_loader, val_loader, dummy_input

def imagenet_dataset(args):
kwargs = {'num_workers': 10, 'pin_memory': True} if torch.cuda.is_available() else {}
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
train_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(os.path.join(args.data_dir, 'train'),
transform=transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])),
batch_size=args.batch_size, shuffle=True, **kwargs)

val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(os.path.join(args.data_dir, 'val'),
transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])),
batch_size=args.batch_size, shuffle=True, **kwargs)
dummy_input = torch.ones(1, 3, 224, 224)
return train_loader, val_loader, dummy_input

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='imagenet',
help='dataset to use, cifar10 or imagenet (default: imagenet)')
parser.add_argument('--model', type=str, default='resnet18',
help='torchvision model to use, e.g. vgg16, resnet18 or mobilenet_v2 (default: resnet18)')
parser.add_argument('--data-dir', type=str, default='/mnt/imagenet/raw_jpeg/2012/',
help='dataset directory')
parser.add_argument('--batch-size', type=int, default=64,
help='input batch size for training (default: 64)')
parser.add_argument('--sparsity', type=float, default=0.1,
help='overall target sparsity')
parser.add_argument('--log-interval', type=int, default=200,
help='how many batches to wait before logging training status')
parser.add_argument('--finetune_epochs', type=int, default=15,
help='the number of finetune epochs after pruning')
parser.add_argument('--lr', type=float, default=0.001, help='the learning rate of model')
return parser.parse_args()


def train(args, model, device, train_loader, criterion, optimizer, epoch, callback=None):
model.train()
loss_sum = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss_sum += loss.item()
loss.backward()
# callback should be inserted between loss.backward() and optimizer.step()
if callback:
callback()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss_sum/(batch_idx+1)))


def test(model, device, criterion, val_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in val_loader:
data, target = data.to(device), target.to(device)
output = model(data)
# sum up batch loss
test_loss += criterion(output, target).item()
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(val_loader.dataset)
accuracy = correct / len(val_loader.dataset)

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
test_loss, correct, len(val_loader.dataset), 100. * accuracy))

return accuracy

def get_data(args):
if args.dataset == 'cifar10':
return cifar10_dataset(args)
elif args.dataset == 'imagenet':
return imagenet_dataset(args)

if __name__ == '__main__':
args = parse_args()
torch.manual_seed(0)
Model = getattr(models, args.model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, val_loader, dummy_input = get_data(args)
net1 = Model(pretrained=True).to(device)
net2 = Model(pretrained=True).to(device)

optimizer1 = torch.optim.SGD(net1.parameters(), lr=args.lr,
momentum=0.9,
weight_decay=5e-4)
scheduler1 = MultiStepLR(
optimizer1, milestones=[int(args.finetune_epochs*0.5), int(args.finetune_epochs*0.75)], gamma=0.1)
criterion1 = torch.nn.CrossEntropyLoss()
optimizer2 = torch.optim.SGD(net2.parameters(), lr=args.lr,
momentum=0.9,
weight_decay=5e-4)
scheduler2 = MultiStepLR(
optimizer2, milestones=[int(args.finetune_epochs*0.5), int(args.finetune_epochs*0.75)], gamma=0.1)
criterion2 = torch.nn.CrossEntropyLoss()

cfglist = [{'op_types':['Conv2d'], 'sparsity':args.sparsity}]
#pruner1 = L1FilterPruner(net1, cfglist, optimizer1)
#pruner2 = Constrained_L1FilterPruner(net2, cfglist, dummy_input.to(device), optimizer2)
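# net1 uses the original ActivationMeanRankFilterPruner; net2 uses the constraint-aware
# variant, which additionally takes dummy_input to analyze the model topology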

pruner1 = ActivationMeanRankFilterPruner(net1, cfglist, optimizer1, statistics_batch_num=10)
pruner2 = ConstrainedActivationMeanRankFilterPruner(net2, cfglist, dummy_input.to(device), optimizer2, statistics_batch_num=10)
for batch_idx, (data, target) in enumerate(train_loader):
data = data.to(device)
net1(data)
net2(data)
if batch_idx > 10:
# enough data to calculate the activation
break

pruner1.compress()
pruner2.compress()
pruner1.export_model('./ori_%f.pth' % args.sparsity, './ori_mask_%f' % args.sparsity)
pruner2.export_model('./cons_%f.pth' % args.sparsity, './cons_mask_%f' % args.sparsity)
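# unwrap the pruner wrappers so that ModelSpeedup can rewrite the original modules
# according to the exported masks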
pruner1._unwrap_model()
pruner2._unwrap_model()
ms1 = ModelSpeedup(net1, dummy_input.to(device), './ori_mask_%f' % args.sparsity)
ms2 = ModelSpeedup(net2, dummy_input.to(device), './cons_mask_%f' % args.sparsity)
ms1.speedup_model()
ms2.speedup_model()
print('Model speedup finished')

acc1 = test(net1, device, criterion1, val_loader)
acc2 = test(net2, device, criterion2, val_loader)
print('After pruning: Acc of Original Pruner %f, Acc of Constrained Pruner %f' % (acc1, acc2))

for epoch in range(args.finetune_epochs):
train(args, net2, device, train_loader,
criterion2, optimizer2, epoch)
scheduler2.step()
acc2 = test(net2, device, criterion2, val_loader)
print('Finetune Epoch %d, acc of constrained pruner %f'%(epoch, acc2))

for epoch in range(args.finetune_epochs):
train(args, net1, device, train_loader,
criterion1, optimizer1, epoch)
scheduler1.step()
acc1 = test(net1, device, criterion1, val_loader)
print('Finetune Epoch %d, acc of original pruner %f'%(epoch, acc1))



acc1 = test(net1, device, criterion1, val_loader)
acc2 = test(net2, device, criterion2, val_loader)
print('After finetuning: Acc of Original Pruner %f, Acc of Constrained Pruner %f' % (acc1, acc2))

flops1, weights1 = count_flops_params(net1, dummy_input.size())
flops2, weights2 = count_flops_params(net2, dummy_input.size())
print('Original pruner flops:{} weight:{}'.format(flops1, weights1))
print('Constrained pruner flops:{} weight:{}'.format(flops2, weights2))
9 changes: 7 additions & 2 deletions src/sdk/pynni/nni/compression/torch/pruning/constants.py
@@ -4,15 +4,20 @@

from ..pruning import LevelPrunerMasker, SlimPrunerMasker, L1FilterPrunerMasker, \
L2FilterPrunerMasker, FPGMPrunerMasker, TaylorFOWeightFilterPrunerMasker, \
ActivationAPoZRankFilterPrunerMasker, ActivationMeanRankFilterPrunerMasker
ActivationAPoZRankFilterPrunerMasker, ActivationMeanRankFilterPrunerMasker, \
L1ConstrainedFilterPrunerMasker, L2ConstrainedFilterPrunerMasker, \
ConstrainedActivationMeanRankFilterPrunerMasker

MASKER_DICT = {
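    # keys ending in '_constrained' map the new constraint-aware pruners to their maskers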
'level': LevelPrunerMasker,
'slim': SlimPrunerMasker,
'l1': L1FilterPrunerMasker,
'l1_constrained': L1ConstrainedFilterPrunerMasker,
'l2': L2FilterPrunerMasker,
'l2_constrained': L2ConstrainedFilterPrunerMasker,
'fpgm': FPGMPrunerMasker,
'taylorfo': TaylorFOWeightFilterPrunerMasker,
'apoz': ActivationAPoZRankFilterPrunerMasker,
'mean_activation': ActivationMeanRankFilterPrunerMasker
'mean_activation': ActivationMeanRankFilterPrunerMasker,
'mean_activation_constrained': ConstrainedActivationMeanRankFilterPrunerMasker
}