This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Refactor model pruning framework #2504

Merged
merged 80 commits on Jun 12, 2020
Changes from 77 commits
Commits (80)
3a45961
Merge pull request #31 from microsoft/master
chicm-ms Aug 6, 2019
633db43
Merge pull request #32 from microsoft/master
chicm-ms Sep 9, 2019
3e926f1
Merge pull request #33 from microsoft/master
chicm-ms Oct 8, 2019
f173789
Merge pull request #34 from microsoft/master
chicm-ms Oct 9, 2019
508850a
Merge pull request #35 from microsoft/master
chicm-ms Oct 9, 2019
5a0e9c9
Merge pull request #36 from microsoft/master
chicm-ms Oct 10, 2019
e7df061
Merge pull request #37 from microsoft/master
chicm-ms Oct 23, 2019
2175cef
Merge pull request #38 from microsoft/master
chicm-ms Oct 29, 2019
2ccbfbb
Merge pull request #39 from microsoft/master
chicm-ms Oct 30, 2019
b29cb0b
Merge pull request #40 from microsoft/master
chicm-ms Oct 30, 2019
4a3ba83
Merge pull request #41 from microsoft/master
chicm-ms Nov 4, 2019
c8a1148
Merge pull request #42 from microsoft/master
chicm-ms Nov 4, 2019
73c6101
Merge pull request #43 from microsoft/master
chicm-ms Nov 5, 2019
6a518a9
Merge pull request #44 from microsoft/master
chicm-ms Nov 11, 2019
a0d587f
Merge pull request #45 from microsoft/master
chicm-ms Nov 12, 2019
e905bfe
Merge pull request #46 from microsoft/master
chicm-ms Nov 14, 2019
4b266f3
Merge pull request #47 from microsoft/master
chicm-ms Nov 15, 2019
237ff4b
Merge pull request #48 from microsoft/master
chicm-ms Nov 21, 2019
682be01
Merge pull request #49 from microsoft/master
chicm-ms Nov 25, 2019
133af82
Merge pull request #50 from microsoft/master
chicm-ms Nov 25, 2019
71a8a25
Merge pull request #51 from microsoft/master
chicm-ms Nov 26, 2019
d2a73bc
Merge pull request #52 from microsoft/master
chicm-ms Nov 26, 2019
198cf5e
Merge pull request #53 from microsoft/master
chicm-ms Dec 5, 2019
cdbfaf9
Merge pull request #54 from microsoft/master
chicm-ms Dec 6, 2019
7e9b29e
Merge pull request #55 from microsoft/master
chicm-ms Dec 10, 2019
d00c46d
Merge pull request #56 from microsoft/master
chicm-ms Dec 10, 2019
de7d1fa
Merge pull request #57 from microsoft/master
chicm-ms Dec 11, 2019
1835ab0
Merge pull request #58 from microsoft/master
chicm-ms Dec 12, 2019
24fead6
Merge pull request #59 from microsoft/master
chicm-ms Dec 20, 2019
0b7321e
Merge pull request #60 from microsoft/master
chicm-ms Dec 23, 2019
60058d4
Merge pull request #61 from microsoft/master
chicm-ms Dec 23, 2019
b111a55
Merge pull request #62 from microsoft/master
chicm-ms Dec 24, 2019
611c337
Merge pull request #63 from microsoft/master
chicm-ms Dec 30, 2019
4a1f14a
Merge pull request #64 from microsoft/master
chicm-ms Jan 10, 2020
7a9e604
Merge pull request #65 from microsoft/master
chicm-ms Jan 14, 2020
b8035b0
Merge pull request #66 from microsoft/master
chicm-ms Feb 4, 2020
47567d3
Merge pull request #67 from microsoft/master
chicm-ms Feb 10, 2020
614d427
Merge pull request #68 from microsoft/master
chicm-ms Feb 10, 2020
a0d9ed6
Merge pull request #69 from microsoft/master
chicm-ms Feb 11, 2020
22dc1ad
Merge pull request #70 from microsoft/master
chicm-ms Feb 19, 2020
0856813
Merge pull request #71 from microsoft/master
chicm-ms Feb 22, 2020
9e97bed
Merge pull request #72 from microsoft/master
chicm-ms Feb 25, 2020
16a1b27
Merge pull request #73 from microsoft/master
chicm-ms Mar 3, 2020
e246633
Merge pull request #74 from microsoft/master
chicm-ms Mar 4, 2020
0439bc1
Merge pull request #75 from microsoft/master
chicm-ms Mar 17, 2020
8b5613a
Merge pull request #76 from microsoft/master
chicm-ms Mar 18, 2020
43e8d31
Merge pull request #77 from microsoft/master
chicm-ms Mar 22, 2020
aae448e
Merge pull request #78 from microsoft/master
chicm-ms Mar 25, 2020
7095716
Merge pull request #79 from microsoft/master
chicm-ms Mar 25, 2020
c51263a
Merge pull request #80 from microsoft/master
chicm-ms Apr 11, 2020
9953c70
Merge pull request #81 from microsoft/master
chicm-ms Apr 14, 2020
f9136c4
Merge pull request #82 from microsoft/master
chicm-ms Apr 16, 2020
b384ad2
Merge pull request #83 from microsoft/master
chicm-ms Apr 20, 2020
ff592dd
Merge pull request #84 from microsoft/master
chicm-ms May 12, 2020
0b5378f
Merge pull request #85 from microsoft/master
chicm-ms May 18, 2020
a53e0b0
Merge pull request #86 from microsoft/master
chicm-ms May 25, 2020
3ea0b89
Merge pull request #87 from microsoft/master
chicm-ms May 28, 2020
cf3fb20
Merge pull request #88 from microsoft/master
chicm-ms May 28, 2020
463c334
Refactor pruners
chicm-ms May 28, 2020
78d9dc8
updates
chicm-ms May 29, 2020
878a750
updates
chicm-ms May 29, 2020
0d53338
updates
chicm-ms Jun 2, 2020
c0eeb41
updates
chicm-ms Jun 2, 2020
2ab0c58
agp
chicm-ms Jun 2, 2020
df8df96
updates
chicm-ms Jun 2, 2020
9d5d884
updates
chicm-ms Jun 2, 2020
edf8785
updates
chicm-ms Jun 2, 2020
693f901
updates
chicm-ms Jun 2, 2020
d1ae471
updates
chicm-ms Jun 2, 2020
0cea379
updates
chicm-ms Jun 3, 2020
c5f7b3c
updates
chicm-ms Jun 3, 2020
6573858
updates
chicm-ms Jun 4, 2020
fae7227
updates
chicm-ms Jun 10, 2020
74c294f
update docs
chicm-ms Jun 11, 2020
07d9b65
updates
chicm-ms Jun 11, 2020
981995d
updates
chicm-ms Jun 11, 2020
8bb81c4
update docs
chicm-ms Jun 11, 2020
d3405b6
updates
chicm-ms Jun 11, 2020
1ba2d7a
updates
chicm-ms Jun 11, 2020
ce3141b
updates
chicm-ms Jun 12, 2020
131 changes: 73 additions & 58 deletions docs/en_US/Compressor/Framework.md
@@ -1,15 +1,33 @@
# Design Doc

## Overview
The model compression framework has two main components: `pruner` and `module wrapper`.
The following example shows how to use a pruner:
```python
from nni.compression.torch import LevelPruner

# load a pretrained model or train a model before using a pruner

configure_list = [{
    'sparsity': 0.7,
    'op_types': ['Conv2d', 'Linear'],
}]

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
pruner = LevelPruner(model, configure_list, optimizer)
model = pruner.compress()

# model is ready for pruning; now start fine-tuning the model,
# the model will be pruned during training automatically
```

A pruner receives the model, config_list, and optimizer as arguments. It prunes the model per the user-defined configuration (defined by `config_list`) during the training loop by adding a hook on `optimizer.step()`.
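As a minimal sketch of that hook mechanism (illustrative only, assuming the pruner exposes an `update_mask()` method; NNI's actual internals differ in details):

```python
import types

def patch_optimizer(pruner, optimizer):
    original_step = optimizer.step

    def patched_step(_, closure=None):
        # run the real parameter update first
        result = original_step(closure) if closure is not None else original_step()
        # then recalculate and apply the masks
        pruner.update_mask()
        return result

    optimizer.step = types.MethodType(patched_step, optimizer)
```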

From an implementation perspective, a pruner consists of a `weight masker` instance and multiple `module wrapper` instances.

### Weight masker
A `weight masker` is the implementation of a pruning algorithm; it can prune a specified layer, wrapped by a `module wrapper`, with a specified sparsity.

### Module wrapper
A `module wrapper` is a module containing:
1. the original module
2. some buffers used by `calc_mask`
Expand All @@ -19,84 +37,81 @@ the reasons to use `module wrapper` :
1. some buffers are needed by `calc_mask` to calculate masks, and these buffers should be registered in `module wrapper` so that the original modules are not contaminated.
2. a new `forward` method is needed to apply masks to the weight before calling the real `forward` method, as sketched below.
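To make the second point concrete, here is a minimal sketch of what such a wrapper can look like (illustrative only, not NNI's exact implementation):

```python
import torch

class MyModuleWrapper(torch.nn.Module):
    def __init__(self, module, config):
        super().__init__()
        self.module = module  # the original module
        self.config = config
        # buffer used by `calc_mask`, registered on the wrapper so that
        # the original module is not contaminated
        self.register_buffer('weight_mask', torch.ones_like(module.weight))

    def forward(self, *inputs):
        # apply the mask to the weight before calling the real forward
        self.module.weight.data = self.module.weight.data * self.weight_mask
        return self.module(*inputs)
```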

### Pruner
A `pruner` is responsible for:
1. Manage / verify config_list.
2. Use `module wrapper` to wrap the model layers and add a hook on `optimizer.step`.
3. Use `weight masker` to calculate masks of layers while pruning.
4. Export pruned model weights and masks (see the sketch below).
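For the last point, a typical export call looks like this (a short sketch using the `Pruner.export_model` API; the file paths are just examples):

```python
# after fine-tuning, export the pruned weights (with masks applied) and the masks themselves
pruner.export_model(model_path='pruned_model.pth', mask_path='mask.pth')
```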

## Implement a new pruning algorithm
Implementing a new pruning algorithm requires implementing a `weight masker` class, which should be a subclass of `WeightMasker`, and a `pruner` class, which should be a subclass of `Pruner`.

An implementation of `weight masker` may look like this:
```python
class MyMasker(WeightMasker):
    def __init__(self, model, pruner):
        super().__init__(model, pruner)
        # You can do some initialization here, such as collecting some statistics
        # data, if it is necessary for your algorithm to calculate the masks.

    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        # calculate the masks based on wrapper.weight, sparsity,
        # and anything else you need
        # mask = ...
        return {'weight_mask': mask}
```
You can reference the [weight masker](https://github.com/microsoft/nni/blob/master/src/sdk/pynni/nni/compression/torch/pruning/structured_pruning.py) implementations provided by NNI to implement your own weight masker.

Contributor
Hard-coded link to source code is not a good idea. Recommend to use link to API docs instead. Still, keep it if you feel necessary.

Contributor Author
Let's keep it for now.

A basic pruner looks like this:
```python
class MyPruner(Pruner):
    def __init__(self, model, config_list, optimizer):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)
        # construct a weight masker instance
        self.masker = MyMasker(model, self)

    def calc_mask(self, wrapper, wrapper_idx=None):
        sparsity = wrapper.config['sparsity']
        if wrapper.if_calculated:
            # Already pruned, do not prune again as a one-shot pruner
            return None
        else:
            # call your masker to actually calculate the mask for this layer
            masks = self.masker.calc_mask(sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
            wrapper.if_calculated = True
            # update masks
            return masks
```

Contributor
Is this pruner some kind of template pruner? If so, why should I write my own pruner instead of using existing pruners with my own masker?

Contributor Author
This is an example, not a template; the implementation of most structured pruners looks similar, but considering we are not changing the LevelPruner, L1FilterPruner interface, a template pruner is not provided in this PR.
You can reference the [pruner](https://github.com/microsoft/nni/blob/master/src/sdk/pynni/nni/compression/torch/pruning/one_shot.py) implementations provided by NNI to implement your own pruner class.

### Set wrapper attribute
Sometimes `calc_mask` must save some state data; therefore, users can use the `set_wrappers_attribute` API to register attributes, just like how buffers are registered in PyTorch modules. These buffers will be registered to the `module wrapper`, and users can access them through the `module wrapper`.
In the above example, we use `set_wrappers_attribute` to set a buffer `if_calculated`, which is used as a flag indicating whether the mask of a layer has already been calculated.
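For instance, a minimal sketch of registering such a buffer and reading it back through the wrappers (assuming the `get_modules_wrapper` helper of the `Pruner` base class):

```python
# inside a Pruner subclass
self.set_wrappers_attribute("if_calculated", False)

# every wrapper now carries the buffer; `calc_mask` can read and update it
for wrapper in self.get_modules_wrapper():
    print(wrapper.name, wrapper.if_calculated)
```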

### Collect data during forward
Sometimes users want to collect some data during the modules' forward method, for example, the mean value of the activation. This can be done by adding a customized collector to the module.

Contributor
Is this designed for activations only? Other layers won't work?


```python
class MyMasker(WeightMasker):
    def __init__(self, model, pruner):
        super().__init__(model, pruner)
        # register a buffer on all wrappers to store the collected activations
        self.pruner.set_wrappers_attribute("collected_activation", [])
        self.pruner.activation = torch.nn.functional.relu

        def collector(module_, input_, output):
            # collect the activation of this module during its forward pass
            module_.collected_activation.append(self.pruner.activation(output.detach().cpu()))

        self.pruner.hook_id = self.pruner.add_activation_collector(collector)
```
The collector function will be called each time the forward method runs.

Users can also remove this collector like this:
```python
collector_id = self.pruner.add_activation_collector(collector)
# ...
self.pruner.remove_activation_collector(collector_id)
```

Contributor
What does ... refer to? What should users do here?

Contributor
Is this document written for internal use or for other people who are trying to write a new pruner?

Contributor
I think both

### Multi-GPU support
12 changes: 11 additions & 1 deletion docs/en_US/Compressor/Pruner.md
@@ -80,10 +80,20 @@ config_list = [{
    'frequency': 1,
    'op_types': ['default']
}]
pruner = AGP_Pruner(model, config_list, pruning_algorithm='level')
pruner.compress()
```

The AGP pruner uses the `LevelPruner` algorithm to prune the weights by default; however, you can set the `pruning_algorithm` parameter to other values to use other pruning algorithms (an example follows the list):
* `level`: LevelPruner
* `slim`: SlimPruner
* `l1`: L1FilterPruner
* `l2`: L2FilterPruner
* `fpgm`: FPGMPruner
* `taylorfo`: TaylorFOWeightFilterPruner
* `apoz`: ActivationAPoZRankFilterPruner
* `mean_activation`: ActivationMeanRankFilterPruner
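
For example, a sketch of running AGP with the L1-filter masking algorithm instead (assuming `model` and `config_list` are defined as above):

```python
# 'l1' selects L1FilterPruner as the per-step masking algorithm
pruner = AGP_Pruner(model, config_list, pruning_algorithm='l1')
pruner.compress()
```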

You should add the code below to update the epoch number when you finish one epoch in your training code.
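In PyTorch, a minimal sketch of this looks like the following (assuming `pruner` is the AGP pruner created above and `train` is your own per-epoch training function):

```python
for epoch in range(num_epochs):
    train(model, optimizer, epoch)  # your existing training loop for one epoch
    pruner.update_epoch(epoch)      # let AGP adjust the target sparsity for the next epoch
```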

Contributor
You should...

Tensorflow code
9 changes: 5 additions & 4 deletions src/sdk/pynni/nni/compression/torch/pruning/__init__.py
@@ -1,8 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .finegrained_pruning import *
from .structured_pruning import *
from .apply_compression import apply_compression_results
from .one_shot import *
from .agp import *
from .lottery_ticket import LotteryTicketPruner