Refactor model pruning framework #2504
# Design Doc

## Overview
The following example shows how to use a pruner:

```python
from nni.compression.torch import LevelPruner

# load a pretrained model, or train a model, before using a pruner

configure_list = [{
    'sparsity': 0.7,
    'op_types': ['Conv2d', 'Linear'],
}]

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
pruner = LevelPruner(model, configure_list, optimizer)
model = pruner.compress()

# the model is ready for pruning; now start fine-tuning it,
# and it will be pruned during training automatically
```
A pruner receives `model`, `config_list` and `optimizer` as arguments. It prunes the model according to the `config_list` during the training loop by adding a hook on `optimizer.step()`.
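As a rough illustration of the hook mechanism, the sketch below wraps `optimizer.step` so that masks are recalculated after every weight update. This is a conceptual sketch only; the names `patched_step` and `update_mask` are illustrative assumptions, not the exact NNI internals.

```python
# conceptual sketch of the optimizer.step() hook, not the exact NNI code
original_step = optimizer.step

def patched_step(*args, **kwargs):
    output = original_step(*args, **kwargs)
    # recalculate and apply masks after every weight update
    pruner.update_mask()  # hypothetical helper standing in for the real logic
    return output

optimizer.step = patched_step
```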
From an implementation perspective, a pruner consists of a `weight masker` instance and multiple `module wrapper` instances.

### Weight masker

A `weight masker` is the implementation of a pruning algorithm; it can prune a specified layer, wrapped by a `module wrapper`, with a specified sparsity.
### Module wrapper

A `module wrapper` is a module containing:

1. the original module,
2. some buffers used by `calc_mask`,
3. a new `forward` method that applies the masks before running the original `forward` method.

The reasons to use a `module wrapper` are:

1. some buffers are needed by `calc_mask` to calculate masks, and these buffers should be registered in the `module wrapper` so that the original modules are not contaminated;
2. a new `forward` method is needed to apply the masks to the weight before calling the real `forward` method, as sketched below.
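A minimal sketch of the wrapper concept follows. The class name `ModuleWrapper` and its details are illustrative assumptions; the real NNI wrapper carries more bookkeeping (bias masks, layer config, and pruner references).

```python
import torch

class ModuleWrapper(torch.nn.Module):
    # minimal sketch of the wrapper concept, illustrative only
    def __init__(self, module):
        super().__init__()
        self.module = module  # the original module
        # buffer used by calc_mask, registered here so that the
        # original module is not contaminated
        self.register_buffer('weight_mask', torch.ones_like(module.weight))

    def forward(self, *inputs):
        # apply the mask before running the original forward
        self.module.weight.data = self.module.weight.data.mul(self.weight_mask)
        return self.module(*inputs)
```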
### Pruner

A `pruner` is responsible for:

1. Managing and verifying the `config_list`.
2. Wrapping the model layers with `module wrapper` instances and adding a hook on `optimizer.step`.
3. Using the `weight masker` to calculate masks of layers while pruning.
4. Exporting the pruned model weights and masks (see the example after this list).
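For example, exporting usually looks like the snippet below (the `export_model` parameters follow NNI's compression API; check the docs for your version):

```python
# after fine-tuning, export the pruned weights and the binary masks
pruner.export_model(model_path='pruned_model.pth', mask_path='mask.pth')
```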
## Implement a new pruning algorithm

Implementing a new pruning algorithm requires implementing a `weight masker` class, which should be a subclass of `WeightMasker`, and a `pruner` class, which should be a subclass of `Pruner`.

An implementation of a `weight masker` may look like this:

```python
class MyMasker(WeightMasker):
    def __init__(self, model, pruner):
        super().__init__(model, pruner)
        # You can do some initialization here, such as collecting statistics,
        # if it is necessary for your algorithm to calculate the masks.

    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        # calculate the masks based on the wrapper's weight, the sparsity,
        # and anything else you need
        # mask = ...
        return {'weight_mask': mask}
```
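For instance, a simple magnitude-based masker could be sketched as follows. This is illustrative only: the class name `MagnitudeMasker` is hypothetical, and it assumes the wrapped layer's weight is reachable as `wrapper.module.weight`.

```python
import torch

class MagnitudeMasker(WeightMasker):
    # hypothetical masker: prunes the weights with the smallest absolute
    # values until the requested sparsity is reached
    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        weight = wrapper.module.weight.data
        num_prune = int(weight.numel() * sparsity)
        if num_prune == 0:
            return {'weight_mask': torch.ones_like(weight)}
        # threshold at the num_prune-th smallest absolute value
        threshold = torch.topk(weight.abs().view(-1), num_prune, largest=False)[0].max()
        mask = torch.gt(weight.abs(), threshold).type_as(weight)
        return {'weight_mask': mask}
```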
You can reference the NNI-provided [weight masker](https://github.com/microsoft/nni/blob/master/src/sdk/pynni/nni/compression/torch/pruning/structured_pruning.py) implementations to implement your own weight masker.

A basic pruner looks like this:
```python
class MyPruner(Pruner):
    def __init__(self, model, config_list, optimizer):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)
        # construct a weight masker instance
        self.masker = MyMasker(model, self)

    def calc_mask(self, wrapper, wrapper_idx=None):
        sparsity = wrapper.config['sparsity']
        if wrapper.if_calculated:
            # already pruned; as a one-shot pruner, do not prune again
            return None
        else:
            # call your masker to actually calculate the mask for this layer
            masks = self.masker.calc_mask(sparsity=sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
            wrapper.if_calculated = True
            # update masks
            return masks
```

> **Review comment:** Is this pruner some kind of template pruner? If so, why should I write my own pruner instead of using existing pruners with my own masker?
>
> **Reply:** This is an example, not a template. The implementations of most structured pruners look similar, but since we are not changing the `LevelPruner` / `L1FilterPruner` interface, a template pruner is not provided in this PR.
You can reference the NNI-provided [pruner](https://github.com/microsoft/nni/blob/master/src/sdk/pynni/nni/compression/torch/pruning/one_shot.py) implementations to implement your own pruner class.

### Set wrapper attribute

Sometimes `calc_mask` must save some state data, so users can use the `set_wrappers_attribute` API to register attributes, just like how buffers are registered in PyTorch modules. These buffers will be registered to the `module wrapper`, and users can access them through the `module wrapper`.

In the above example, we use `set_wrappers_attribute` to set a buffer `if_calculated`, which is used as a flag indicating whether the mask of a layer has already been calculated.
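As a small sketch of the pattern (`get_modules_wrapper` is the `Pruner` accessor for the wrapped layers; the reset loop is illustrative):

```python
# register `if_calculated` on every wrapper (typically done once in __init__)
self.set_wrappers_attribute("if_calculated", False)

# each wrapper now carries its own copy, which can be read or reset later
for wrapper in self.get_modules_wrapper():
    wrapper.if_calculated = False
```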
### Collect data during forward

Sometimes users want to collect some data during the modules' `forward` method, for example, the mean value of the activation. This can be done by adding a customized collector to the module.
```python
class MyMasker(WeightMasker):
    def __init__(self, model, pruner):
        super().__init__(model, pruner)
        # set the attribute `collected_activation` for all wrappers to store
        # the collected activations for each layer
        self.pruner.set_wrappers_attribute("collected_activation", [])
        self.activation = torch.nn.functional.relu

        def collector(wrapper, input_, output):
            # the collected activation can be accessed via each wrapper's
            # `collected_activation` attribute
            wrapper.collected_activation.append(self.activation(output.detach().cpu()))

        self.pruner.hook_id = self.pruner.add_activation_collector(collector)
```
The collector function will be called each time the `forward` method runs.

Users can also remove this collector, like this:

```python
# save the collector identifier
collector_id = self.pruner.add_activation_collector(collector)

# when the collector is no longer needed, it can be removed using
# the saved collector identifier
self.pruner.remove_activation_collector(collector_id)
```

> **Review comment:** Is this document written for internal use or for other people who are trying to write a new pruner?
>
> **Reply:** I think both.
### Multi-GPU support

During multi-GPU training, buffers and parameters are copied to each GPU every time the `forward` method runs on multiple GPUs. If buffers and parameters are updated in the `forward` method, an in-place update is needed to ensure the update is effective, as sketched below.
Since `calc_mask` is called in the `optimizer.step` method, which happens after the `forward` method and runs on only one GPU, it supports multi-GPU naturally.
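To illustrate the in-place requirement, here is a hedged sketch with a hypothetical wrapper and buffer: rebinding a buffer attribute inside `forward` only changes the per-GPU replica, while an in-place update writes through to the shared storage.

```python
import torch

class WrapperSketch(torch.nn.Module):
    # illustrative only: contrasts rebinding a buffer with updating it in place
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.register_buffer('batch_count', torch.zeros(1))

    def forward(self, x):
        # self.batch_count = self.batch_count + 1  # rebinding: lost on GPU replicas
        self.batch_count.add_(1)                   # in-place: the update is effective
        return self.module(x)
```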
The PR also updates the pruning package imports:

```diff
@@ -1,8 +1,9 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-from .pruners import *
-from .weight_rank_filter_pruners import *
-from .activation_rank_filter_pruners import *
+from .finegrained_pruning import *
+from .structured_pruning import *
 from .apply_compression import apply_compression_results
-from .gradient_rank_filter_pruners import *
+from .one_shot import *
+from .agp import *
+from .lottery_ticket import LotteryTicketPruner
```
> **Review comment:** A hard-coded link to source code is not a good idea. I recommend linking to the API docs instead. Still, keep it if you feel it's necessary.
>
> **Reply:** Let's keep it for now.