This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

add new QAT_quantization #1732

Merged: 58 commits, Nov 25, 2019

Commits (58)
400e312
fix tools/nni_gpu_tool pylint
Oct 31, 2019
d0f65c8
Merge branch 'master' of https://github.com/microsoft/nni
Oct 31, 2019
5d1a115
add instrument_layer_hook && fix quantization param calc
Nov 3, 2019
cb89e16
Merge branch 'master' of https://github.com/microsoft/nni
Nov 3, 2019
493b0f3
add QAT example
Nov 4, 2019
bd809a7
remove data
Nov 4, 2019
c41a6cc
modify framework
Cjkkkk Nov 4, 2019
1b02f6a
rm irrelevant files
Cjkkkk Nov 4, 2019
61d471e
fix pylint for QAT quantizer
Cjkkkk Nov 4, 2019
3a3f3ce
resolve conflicts
Cjkkkk Nov 4, 2019
f884f2a
API refactor
Cjkkkk Nov 4, 2019
47d639f
warning for no weight parameter
Cjkkkk Nov 4, 2019
985dc43
API refactor
Cjkkkk Nov 4, 2019
c2e3871
fix pylint for QAT_torch_quantizer.py
Cjkkkk Nov 4, 2019
71c3369
init modules_to_compress to None
Cjkkkk Nov 4, 2019
2897613
modify config
Cjkkkk Nov 5, 2019
e837314
add doc string for QAT_quantizer
Cjkkkk Nov 5, 2019
e229624
rename quant_delay to quant_start_step
Cjkkkk Nov 5, 2019
bca4b51
remove EMA
Cjkkkk Nov 5, 2019
1ddba06
add docstring to explain dequantize in dequantize method
Cjkkkk Nov 5, 2019
439cdec
fix typo
Cjkkkk Nov 5, 2019
34f3b62
fix
Cjkkkk Nov 5, 2019
8ef4c4b
change to stateless
Cjkkkk Nov 5, 2019
4a1d122
update doc
Cjkkkk Nov 5, 2019
6a29346
fix
Cjkkkk Nov 6, 2019
40f9062
remove return name in docstring
Cjkkkk Nov 6, 2019
81bb549
fix test
Cjkkkk Nov 6, 2019
2c09777
update doc
Cjkkkk Nov 6, 2019
ce7bea6
fix docstring
Cjkkkk Nov 7, 2019
8102990
fix compressor doc
Cjkkkk Nov 11, 2019
d20c820
fix wrong return statement & restore doc
Cjkkkk Nov 11, 2019
5188451
fix
Cjkkkk Nov 12, 2019
4d5a65b
fix
Cjkkkk Nov 12, 2019
9cfd04b
fix name convention
Cjkkkk Nov 13, 2019
6554b3a
fix quant_bits
Cjkkkk Nov 14, 2019
ecef611
fix quant_bits
Cjkkkk Nov 14, 2019
4b33383
fix pylint
Cjkkkk Nov 15, 2019
8b48745
fix shift error
Cjkkkk Nov 15, 2019
8026520
fix pylint
Cjkkkk Nov 18, 2019
f5d3191
fix
Cjkkkk Nov 18, 2019
a74fe20
fix
Cjkkkk Nov 18, 2019
bef64a2
fix straight through estimator
Cjkkkk Nov 18, 2019
c5ff101
fix docs
Cjkkkk Nov 18, 2019
0a2fc51
add ema for output
Cjkkkk Nov 18, 2019
32b74e4
add EMA
Cjkkkk Nov 18, 2019
34bb9af
change update_param to stateless
Cjkkkk Nov 18, 2019
38749e2
fix docs
Cjkkkk Nov 18, 2019
ed49531
fix docs
Cjkkkk Nov 18, 2019
118a1f7
Merge branch 'master' of https://github.com/microsoft/nni
Cjkkkk Nov 19, 2019
0fc4d6d
fix docstring & add quant bits length check
Cjkkkk Nov 20, 2019
f833411
Merge branch 'master' of https://github.com/microsoft/nni
Cjkkkk Nov 21, 2019
8f9d0f3
unit test for QAT
Cjkkkk Nov 21, 2019
dcf4483
unit test for QAT
Cjkkkk Nov 21, 2019
2825928
unit test for QAT
Cjkkkk Nov 21, 2019
25daf32
add modules_detection test for quantization framework
Cjkkkk Nov 21, 2019
9ecac14
add modules_detection test for quantization framework
Cjkkkk Nov 21, 2019
3437dcb
revert test
Cjkkkk Nov 21, 2019
ce2b194
fix doc string
Cjkkkk Nov 25, 2019
36 changes: 25 additions & 11 deletions docs/en_US/Compressor/Quantizer.md
@@ -28,27 +28,41 @@ In [Quantization and Training of Neural Networks for Efficient Integer-Arithmeti
### Usage
You can quantize your model to 8 bits with the code below before your training code.

Tensorflow code
```python
from nni.compressors.tensorflow import QAT_Quantizer
config_list = [{ 'q_bits': 8, 'op_types': ['default'] }]
quantizer = QAT_Quantizer(tf.get_default_graph(), config_list)
quantizer.compress()
```
PyTorch code
```python
from nni.compressors.torch import QAT_Quantizer
config_list = [{ 'q_bits': 8, 'op_types': ['default'] }]
model = Mnist()

config_list = [{
'quant_types': ['weight'],
'quant_bits': {
'weight': 8,
}, # you can just use an `int` here because all `quant_types` share the same bit length; see the config for `ReLU6` below.
'op_types':['Conv2d', 'Linear']
}, {
'quant_types': ['output'],
'quant_bits': 8,
'quant_start_step': 7000,
'op_types':['ReLU6']
}]
quantizer = QAT_Quantizer(model, config_list)
quantizer.compress()
```

You can view the example for more information.

#### User configuration for QAT Quantizer
* **q_bits:** This specifies the number of bits that operations should be quantized to


* **quant_types:** list of string
type of quantization you want to apply; currently supports 'weight', 'input', 'output'
* **quant_bits:** int or dict of {str : int}
bit length of quantization; the key is the quantization type and the value is the bit length, e.g. {'weight': 8}.
When the value is an int, all quantization types share the same bit length (see the sketch after this list).
* **quant_start_step:** int
disable quantization until the model has been run for a certain number of steps. This allows the network to enter a more stable
state, where activation quantization ranges do not exclude a significant fraction of values. Default value is 0.
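The following is a minimal illustrative sketch (the layer types and bit widths are arbitrary assumptions, not taken from this PR) contrasting the two accepted forms of `quant_bits`:

```python
# int form: every type listed in quant_types shares the same bit length
config_int = [{
    'quant_types': ['weight', 'output'],
    'quant_bits': 8,
    'op_types': ['Conv2d']
}]

# dict form: each quantization type is given its own bit length
# (here both happen to be 8, but each entry could use a different value)
config_dict = [{
    'quant_types': ['weight', 'output'],
    'quant_bits': {'weight': 8, 'output': 8},
    'op_types': ['Conv2d']
}]
```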
Contributor review comment: “where activation quantization ranges do not exclude a significant fraction of values”, I don't understand this sentence, could you explain a little more?
### Note
Batch normalization folding is currently not supported.
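For background only (this PR leaves `fold_bn` as a TODO): batch normalization folding typically merges a BN layer's running statistics and affine parameters into the preceding convolution before quantization. A minimal sketch of the standard folding formulas, under the usual assumptions and not this PR's implementation, looks like:

```python
import torch

def fold_bn_into_conv(conv_w, conv_b, gamma, beta, mean, var, eps=1e-5):
    """Background sketch of standard BN folding, not this PR's implementation."""
    std = torch.sqrt(var + eps)
    # scale each output channel of the conv weight by gamma / std
    w_fold = conv_w * (gamma / std).reshape(-1, 1, 1, 1)
    # fold the BN shift into the conv bias
    b_fold = gamma * (conv_b - mean) / std + beta
    return w_fold, b_fold
```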
***

## DoReFa Quantizer
98 changes: 98 additions & 0 deletions examples/model_compress/QAT_torch_quantizer.py
@@ -0,0 +1,98 @@
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import QAT_Quantizer


class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
self.relu1 = torch.nn.ReLU6()
self.relu2 = torch.nn.ReLU6()
self.relu3 = torch.nn.ReLU6()

def forward(self, x):
x = self.relu1(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = self.relu2(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
x = self.relu3(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)


def train(model, quantizer, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
quantizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)

print('Loss: {} Accuracy: {}%\n'.format(
test_loss, 100 * correct / len(test_loader.dataset)))

def main():
torch.manual_seed(0)
device = torch.device('cpu')

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True, transform=trans),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=trans),
batch_size=1000, shuffle=True)

model = Mnist()

'''you can change this to DoReFaQuantizer to use it instead:
DoReFaQuantizer(model, configure_list).compress()
'''
configure_list = [{
'quant_types': ['weight'],
'quant_bits': {
'weight': 8,
}, # you can just use an `int` here because all `quant_types` share the same bit length; see the config for `ReLU6` below.
'op_types':['Conv2d', 'Linear']
}, {
'quant_types': ['output'],
'quant_bits': 8,
'quant_start_step': 7000,
'op_types':['ReLU6']
}]
quantizer = QAT_Quantizer(model, configure_list)
quantizer.compress()

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(10):
print('# Epoch {} #'.format(epoch))
train(model, quantizer, device, train_loader, optimizer)
test(model, device, test_loader)


if __name__ == '__main__':
main()
196 changes: 183 additions & 13 deletions src/sdk/pynni/nni/compression/torch/builtin_quantizers.py
@@ -22,30 +22,200 @@ def quantize_weight(self, weight, config, op_name, **kwargs):
return weight.div(scale).type(torch.int8).type(orig_type).mul(scale)


def update_ema(biased_ema, value, decay, step):
"""
calculate the biased and unbiased statistics at each step using the exponential moving average (EMA) method

Parameters
----------
biased_ema : float
previous stat value
value : float
current stat value
decay : float
the weight of previous stat value, larger means smoother curve
step : int
current step

Returns
-------
float, float
"""
biased_ema = biased_ema * decay + (1 - decay) * value
unbiased_ema = biased_ema / (1 - decay ** step) # Bias correction
return biased_ema, unbiased_ema
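# Worked example for update_ema (illustrative values, assumed for this sketch):
# with decay = 0.99, step = 1, value = 2.0 and a starting biased_ema of 0.0:
#   biased_ema   = 0.0 * 0.99 + (1 - 0.99) * 2.0  = 0.02
#   unbiased_ema = 0.02 / (1 - 0.99 ** 1)          = 2.0
# i.e. the bias correction recovers the raw value at early steps, when the
# biased average is still dominated by its zero initialization.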

def update_quantization_param(bits, rmin, rmax):
"""
calculate the `zero_point` and `scale`.

Parameters
----------
bits : int
quantization bits length
rmin : float
minimum value of the real (floating-point) range
rmax : float
maximum value of the real (floating-point) range

Returns
-------
float, float
"""
# extend the [min, max] interval to ensure that it contains 0.
# Otherwise, we would not meet the requirement that 0 be an exactly
# representable value.
rmin = min(rmin, 0)
rmax = max(rmax, 0)

# the min and max quantized values, as floating-point values
qmin = 0
qmax = (1 << bits) - 1
# First determine the scale.
scale = (rmax - rmin) / (qmax - qmin)

# Zero-point computation.
initial_zero_point = qmin - rmin / scale

# Now we need to nudge the zero point to be an integer
nudged_zero_point = 0
if initial_zero_point < qmin:
nudged_zero_point = qmin
elif initial_zero_point > qmax:
nudged_zero_point = qmax
else:
nudged_zero_point = torch.round(initial_zero_point)

return scale, nudged_zero_point
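# Worked example for update_quantization_param (illustrative values, assumed for this sketch):
# with bits = 8, rmin = -1.0, rmax = 3.0:
#   qmin, qmax         = 0, 255
#   scale              = (3.0 - (-1.0)) / 255   ~= 0.01569
#   initial_zero_point = 0 - (-1.0) / 0.01569   ~= 63.75
#   nudged_zero_point  = round(63.75)            = 64
# so the real value 0.0 maps exactly to the integer 64, as required.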


def get_bits_length(config, quant_type):
if isinstance(config["quant_bits"], int):
return config["quant_bits"]
else:
return config["quant_bits"].get(quant_type)


class QAT_Quantizer(Quantizer):
"""Quantizer using the DoReFa scheme, as defined in:
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""
def __init__(self, model, config_list):
"""
config_list: supported keys:
- q_bits
Parameters
----------
model : torch.nn.Module
the model to be quantized
config_list : list of dict
list of configurations for quantization
supported keys for dict:
- quant_types : list of string
type of quantization you want to apply; currently supports 'weight', 'input', 'output'
- quant_bits : int or dict of {str : int}
bit length of quantization; the key is the quantization type and the value is the bit length, e.g. {'weight': 8}.
When the value is an int, all quantization types share the same bit length.
- quant_start_step : int
disable quantization until the model has been run for a certain number of steps. This allows the network to enter a more stable
state, where activation quantization ranges do not exclude a significant fraction of values. Default value is 0.
- op_types : list of string
types of nn.Module you want to apply quantization to, e.g. 'Conv2d'
"""
super().__init__(model, config_list)
self.steps = 1
modules_to_compress = self.detect_modules_to_compress()
for layer, config in modules_to_compress:
layer.module.register_buffer("zero_point", None)
layer.module.register_buffer("scale", None)
if "output" in config.get("quant_types", []):
layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
layer.module.register_buffer('tracked_min', torch.zeros(1))
layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
layer.module.register_buffer('tracked_max', torch.zeros(1))

def quantize_weight(self, weight, config, **kwargs):
if config['q_bits'] <= 1:
def _quantize(self, bits, op, real_val):
"""
quantize real value.

Parameters
----------
bits : int
quantization bits length
op : torch.nn.Module
target module
real_val : float
real value to be quantized

Returns
-------
float
"""
transformed_val = op.zero_point + real_val / op.scale
qmin = 0
qmax = (1 << bits) - 1
clamped_val = torch.clamp(transformed_val, qmin, qmax)
quantized_val = torch.round(clamped_val)
return quantized_val

def _dequantize(self, op, quantized_val):
"""
dequantize quantized value.
Because we simulate quantization in the training process, all the computations still happen in floating point, which means we
first quantize tensors and then dequantize them. For more details, please refer to the paper.

Parameters
----------
op : torch.nn.Module
target module
quantized_val : float
quantized value to be dequantized

Returns
-------
float
"""
real_val = op.scale * (quantized_val - op.zero_point)
return real_val
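# Worked round trip through _quantize and _dequantize (illustrative values, assumed
# for this sketch, continuing the update_quantization_param example above):
# with bits = 8, op.scale ~= 0.01569, op.zero_point = 64 and real_val = 0.5:
#   transformed_val = 64 + 0.5 / 0.01569           ~= 95.87
#   quantized_val   = round(clamp(95.87, 0, 255))   = 96
#   dequantized     = 0.01569 * (96 - 64)          ~= 0.502
# The gap between 0.502 and 0.5 is the simulated quantization error that the
# network learns to tolerate during training.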

def quantize_weight(self, weight, config, op, **kwargs):
weight_bits = get_bits_length(config, 'weight')
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"

if quant_start_step > self.steps:
return weight
a = torch.min(weight)
b = torch.max(weight)
n = pow(2, config['q_bits'])
scale = (b-a)/(n-1)
zero_point = a
out = torch.round((weight - zero_point)/scale)
out = out*scale + zero_point
orig_type = weight.dtype
return out.type(orig_type)
rmin, rmax = torch.min(weight), torch.max(weight)
op.scale, op.zero_point = update_quantization_param(weight_bits, rmin, rmax)
out = self._quantize(weight_bits, op, weight)
out = self._dequantize(op, out)
return out

def quantize_output(self, output, config, op, **kwargs):
output_bits = get_bits_length(config, 'output')
quant_start_step = config.get('quant_start_step', 0)
assert output_bits >= 1, "quant bits length should be at least 1"

if quant_start_step > self.steps:
return output

current_min, current_max = torch.min(output), torch.max(output)
op.tracked_min_biased, op.tracked_min = update_ema(op.tracked_min_biased, current_min, op.ema_decay, self.steps)
op.tracked_max_biased, op.tracked_max = update_ema(op.tracked_max_biased, current_max, op.ema_decay, self.steps)
op.scale, op.zero_point = update_quantization_param(output_bits, op.tracked_min, op.tracked_max)
out = self._quantize(output_bits, op, output)
out = self._dequantize(op, out)
return out

def fold_bn(self, config, **kwargs):
# TODO simulate folded weight
pass

def step(self):
"""
override `Compressor`'s `step` method; quantization only happens after a certain number of steps
"""
self.steps += 1
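# Usage note (based on the QAT_torch_quantizer.py example in this PR): the training
# loop is expected to call quantizer.step() once per batch, after optimizer.step(),
# so that self.steps advances; both the quant_start_step gating and the EMA bias
# correction above depend on this counter.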


class DoReFaQuantizer(Quantizer):