Add post training observer_quantizer #3915
@@ -0,0 +1,117 @@
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import ObserverQuantizer
import sys
sys.path.append('../models')
from mnist.naive import NaiveModel

def train(model, device, train_loader, optimizer):
    model.to(device)
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print('Loss: {} Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))

def calibration(model, device, test_loader):
    model.eval()
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            model(data)

def main():
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000, shuffle=True)

    model = NaiveModel()
    configure_list = [{
        'quant_types': ['weight', 'input'],
        'quant_bits': {'weight': 8, 'input': 8},
        'op_names': ['conv1'],
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu1'],
    }, {
        'quant_types': ['weight', 'input'],
        'quant_bits': {'weight': 8, 'input': 8},
        'op_names': ['conv2'],
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu2'],
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['max_pool2'],
    }
    ]
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    # Train the model to get a baseline performance
    for epoch in range(5):
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)

    test(model, device, test_loader)

    # Construct the ObserverQuantizer. Note that currently ObserverQuantizer
    # only works in evaluation mode.
    quantizer = ObserverQuantizer(model.eval(), configure_list, optimizer)
    # Use the test data set for calibration; this does not change the model parameters.
    calibration(model, device, test_loader)
    # Obtain the quantization information and switch the model to "accuracy verification" mode.
    quantizer.compress()

    # Measure the accuracy of the quantized model.
    test(model, device, test_loader)

    model_path = "mnist_model.pth"
    calibration_path = "mnist_calibration.pth"
    calibration_config = quantizer.export_model(model_path, calibration_path)
    print("calibration_config: ", calibration_config)

    # For now the quantization settings of ObserverQuantizer do not match TensorRT's,
    # so TensorRT conversion is not supported.
    # Current settings:
    #   weight     : per_tensor_symmetric, qint8
    #   activation : per_tensor_affine, quint8, reduce_range=True


if __name__ == '__main__':
    main()
@@ -0,0 +1,3 @@
from torch.quantization import default_weight_observer, default_histogram_observer

__all__ = ["default_weight_observer", "default_histogram_observer"]
@@ -3,12 +3,15 @@
import logging
import copy
from collections import defaultdict
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import QuantizerSchema
from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType

__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']
from .observers import default_weight_observer, default_histogram_observer

__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer', 'ObserverQuantizer']

logger = logging.getLogger(__name__)

@@ -120,6 +123,231 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
    return grad_output

class ObserverQuantizer(Quantizer):

Review comment: please add a ut for this quantizer
Reply: done
"""This quantizer uses observers to record weight/activation statistics to get quantization information. | ||
The whole process can be divided into three steps: | ||
1. It will register observers to the place where quantization would happen (just like registering hooks). | ||
2. The observers would record tensors' statistics during calibration. | ||
3. Scale & zero point would be obtained after calibration. | ||
Note that the observer type, tensor dtype and quantization qscheme are hard coded for now. Their customization | ||
are under development and will be ready soon. | ||
""" | ||
|
||
    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        # NOTE: this quantizer is experimental for now. The dtype and qscheme of quantization
        # are hard-coded.
        # TODO:
        # 1. support dtype and qscheme customization through config_list. Current settings:

Review comment: if dtype and qscheme can be applied on each layer separately, then it is better to support them in config_list. otherwise, it is better to support them as quantizer's initialization argument
Reply: Agree, and I think it should be supported in config_list.

        #    weight observer     : per_tensor_symmetric, qint8
        #    activation observer : per_tensor_affine, quint8, reduce_range=True
        # 2. add more kinds of observers, such as Kullback-Leibler divergence.
        # 3. add batch normalization folding
        assert not model.training, "Currently the observer quantizer only works in evaluation mode."
        self.quant_grad = QuantForward()
        self.device = next(model.parameters()).device
        modules_to_compress = self.get_modules_to_compress()
        all_observers = defaultdict(dict)
        weight_q_min, weight_q_max = -127, 127
        activation_q_min, activation_q_max = 0, 127  # reduce_range is set to True

Review comment: Why do we set the quantized activation range to (0, 127) instead of (0, 255)?
Reply: By default, the activation observer's reduce_range is set to True, which halves the range from (0, 255) to (0, 127).
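A small sketch of the record-then-calculate_qparams flow described in the class docstring, using the same PyTorch observers, which also shows where the (0, 127) range comes from (behaviour may vary slightly across PyTorch versions):

```python
import torch
from torch.quantization import HistogramObserver

# Activation observer with reduce_range=True (what default_histogram_observer builds).
obs = HistogramObserver(reduce_range=True)

# Step 2 of the docstring: feed calibration tensors through the observer to record statistics.
for _ in range(4):
    obs(torch.rand(8, 16))

# Step 3: derive scale and zero_point from the recorded statistics.
scale, zero_point = obs.calculate_qparams()

# With dtype=quint8 and reduce_range=True the usable quantized range is halved
# from [0, 255] to [0, 127], which is why activation_q_min/activation_q_max are 0/127 here.
print(scale, zero_point)
```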
        self.compressed = False

        for layer, config in modules_to_compress:
            layer_name = layer.name
            module = layer.module
            if "weight" in config.get("quant_types", []):
                all_observers[layer_name]["weight"] = default_weight_observer()
                setattr(module, "weight_qmax", weight_q_max)
                setattr(module, "weight_qmin", weight_q_min)
            if "input" in config.get("quant_types", []):
                all_observers[layer_name]["input"] = default_histogram_observer()
                setattr(module, "input_qmax", activation_q_max)
                setattr(module, "input_qmin", activation_q_min)
            if "output" in config.get("quant_types", []):
                all_observers[layer_name]["output"] = default_histogram_observer()
                setattr(module, "output_qmax", activation_q_max)
                setattr(module, "output_qmin", activation_q_min)
        self.all_observers = all_observers
        self.bound_model.to(self.device)

    def validate_config(self, model, config_list):
        schema = QuantizerSchema([{
            Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]),
            Optional('quant_bits'): Or(And(int, lambda n: n == 8), Schema({

Review comment: For post-training quantization, we support int8 right now. If we want to support all bit types or mixed precision, is there any obstacle?
Reply: I think there would not be major obstacles. The reason why we only support int8 right now is that PyTorch quantization observers only support 8 bit quantization (see here). To support other bit widths, we should extend/customize the observers. (A rough sketch follows after this method.)

                Optional('weight'): And(int, lambda n: n == 8),
                Optional('output'): And(int, lambda n: n == 8),
                Optional('input'): And(int, lambda n: n == 8),
            })),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)

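As noted in the thread above, the 8-bit restriction comes from the stock PyTorch observers rather than from the quantizer itself; once min/max statistics are recorded, qparams for another bit width follow from the same formula. A rough sketch with a hypothetical helper (not part of this PR):

```python
import torch

def affine_qparams(min_val: float, max_val: float, bits: int = 8):
    """Derive scale/zero_point for unsigned affine quantization at an arbitrary bit width."""
    qmin, qmax = 0, 2 ** bits - 1
    # make sure zero is representable, as the PyTorch observers do
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    scale = max((max_val - min_val) / (qmax - qmin), torch.finfo(torch.float32).eps)
    zero_point = int(round(qmin - min_val / scale))
    zero_point = min(max(zero_point, qmin), qmax)
    return scale, zero_point

# e.g. 4-bit qparams from calibrated activation statistics
print(affine_qparams(min_val=-0.8, max_val=2.5, bits=4))
```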
    def record(self, wrapper, quant_type, tensor):
        name = wrapper.name
        observer = self.all_observers[name][quant_type]
        if isinstance(tensor, tuple):
            # NB: This only works for single tensor
            tensor = (t.cpu() for t in tensor)
            observer(*tensor)
        else:
            observer(tensor.cpu())

    def calculate_qparams(self, name, quant_type):
        observer = self.all_observers[name][quant_type]
        scale, zero_point = observer.calculate_qparams()
        return scale, zero_point

    def _quantize(self, x, scale, zero_point, qmin, qmax):
        x = x / scale + zero_point
        x = torch.clamp(x, qmin, qmax)
        x = torch.round(x)
        x = (x - zero_point) * scale
        return x
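For intuition, `_quantize` above simulates quantization in floating point (scale, clamp, round, dequantize). For per-tensor affine settings it should agree with PyTorch's built-in fake-quantize op; a small sanity-check sketch with illustrative values:

```python
import torch

x = torch.randn(4, 4)
scale, zero_point, qmin, qmax = 0.05, 0, -127, 127

# Manual simulation, mirroring ObserverQuantizer._quantize.
manual = (torch.round(torch.clamp(x / scale + zero_point, qmin, qmax)) - zero_point) * scale

# PyTorch's built-in fake quantization.
builtin = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, qmin, qmax)

print(torch.allclose(manual, builtin))  # expected: True
```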

    def quantize_input(self, *inputs, wrapper, **kwargs):
        if self.compressed:
            module = wrapper.module
            new_input = self._quantize(inputs[0],
                                       module.input_scale,
                                       module.input_zero_point,
                                       module.input_qmin,
                                       module.input_qmax)
            list_inp = list(inputs)
            list_inp[0] = new_input
            inputs = tuple(list_inp)
        else:
            self.record(wrapper, 'input', inputs)
        return inputs

    def quantize_weight(self, wrapper, **kwargs):
        # If ObserverQuantizer.compress has been executed, the weight has already been set to
        # the pseudo-quantized one, so there is no need to quantize it again.
        if self.compressed:
            return

        module = wrapper.module
        old_weight = module.weight
        self.record(wrapper, 'weight', old_weight)

    def quantize_output(self, output, wrapper, **kwargs):
        if self.compressed:
            module = wrapper.module
            new_output = self._quantize(output,
                                        module.output_scale,
                                        module.output_zero_point,
                                        module.output_qmin,
                                        module.output_qmax)
        else:
            self.record(wrapper, 'output', output)
            new_output = output
        return new_output

    def compress(self):
        """
        Calculate quantization information of each tensor. Note that the inference of
        the compressed model will no longer update the corresponding observers. Instead, the
        quantization process will be simulated, which is used to test the accuracy of quantization.
        """
        modules_to_compress = self.get_modules_to_compress()
        for layer, config in modules_to_compress:
            module = layer.module
            if "weight" in config.get("quant_types", []):
                scale, zero_point = self.calculate_qparams(layer.name, 'weight')
                module.register_buffer('weight_scale', scale.to(self.device))
                module.register_buffer('weight_zero_point', zero_point.to(self.device))
                weight = module.weight
                quantized_weight = self._quantize(weight,
                                                  module.weight_scale,
                                                  module.weight_zero_point,
                                                  module.weight_qmin,
                                                  module.weight_qmax)
                delattr(module, 'weight')
                module.register_parameter('weight', torch.nn.Parameter(quantized_weight))
            if "input" in config.get("quant_types", []):
                scale, zero_point = self.calculate_qparams(layer.name, 'input')
                module.register_buffer('input_scale', scale.to(self.device))
                module.register_buffer('input_zero_point', zero_point.to(self.device))
            if "output" in config.get("quant_types", []):
                scale, zero_point = self.calculate_qparams(layer.name, 'output')
                module.register_buffer('output_scale', scale.to(self.device))
                module.register_buffer('output_zero_point', zero_point.to(self.device))
        self.compressed = True
        super().compress()

    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional).

        Parameters
        ----------
        model_path : str
            path to save the quantized model weights
        calibration_path : str
            (optional) path to save the quantization parameters after calibration
        onnx_path : str
            (optional) path to save the onnx model
        input_shape : list or tuple
            input shape of the onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting the onnx file.
            the tensor is placed on cpu if ``device`` is None

        Returns
        -------
        Dict
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            if hasattr(module, 'weight_scale') or hasattr(module, 'input_scale') or hasattr(module, 'output_scale'):
                calibration_config[name] = {}
            if hasattr(module, 'weight_scale'):
                calibration_config[name]['weight_bit'] = 8
                val = float(module.weight_scale * module.weight_qmax)
                calibration_config[name]['tracked_max_weight'] = val
                calibration_config[name]['tracked_min_weight'] = -val
                calibration_config[name]['tracked_weight_qmin'] = -127
                calibration_config[name]['tracked_weight_qmax'] = 127
                # refactor these magic numbers when customizations of dtype and qscheme are ready.
            if hasattr(module, 'input_scale'):
                calibration_config[name]['input_bit'] = 8
                max_input = float(module.input_scale * (module.input_qmax - module.input_zero_point))
                min_input = float(module.input_scale * (module.input_qmin - module.input_zero_point))
                calibration_config[name]['tracked_min_input'] = min_input
                calibration_config[name]['tracked_max_input'] = max_input
                calibration_config[name]['tracked_input_qmin'] = 0
                calibration_config[name]['tracked_input_qmax'] = 127
            if hasattr(module, 'output_scale'):
                calibration_config[name]['activation_bit'] = 8
                max_input = float(module.output_scale * (module.output_qmax - module.output_zero_point))
                min_input = float(module.output_scale * (module.output_qmin - module.output_zero_point))
                calibration_config[name]['tracked_min_activation'] = min_input
                calibration_config[name]['tracked_max_activation'] = max_input
                calibration_config[name]['tracked_activation_qmin'] = 0
                calibration_config[name]['tracked_activation_qmax'] = 127
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
                               input_shape, device)

        return calibration_config
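For the layers in the MNIST example above (conv1 quantizing weight and input, relu1 quantizing output), the returned dictionary therefore looks roughly like this (keys taken from the code above, values purely illustrative):

```python
calibration_config = {
    'conv1': {
        'weight_bit': 8,
        'tracked_max_weight': 0.42, 'tracked_min_weight': -0.42,
        'tracked_weight_qmin': -127, 'tracked_weight_qmax': 127,
        'input_bit': 8,
        'tracked_min_input': -0.43, 'tracked_max_input': 2.82,
        'tracked_input_qmin': 0, 'tracked_input_qmax': 127,
    },
    'relu1': {
        'activation_bit': 8,
        'tracked_min_activation': 0.0, 'tracked_max_activation': 3.17,
        'tracked_activation_qmin': 0, 'tracked_activation_qmax': 127,
    },
}
```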

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
        del_attr_list = ['old_weight', 'steps', 'weight_qmax', 'weight_qmin', 'input_qmax', 'input_qmin',
                         'output_qmax', 'output_qmin', 'weight_scale', 'weight_zero_point', 'input_scale',
                         'input_zero_point', 'output_scale', 'output_zero_point']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)

class QAT_Quantizer(Quantizer):
    """Quantizer defined in:
    Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference

Review comment: Why can't we support TensorRT currently? Is it for reasons that some runtime errors may be raised, or that the results would not be aligned between simulated quantization and TensorRT?
Reply: Because the dtype and qscheme of the PyTorch default observers are different from those in TensorRT. For example, TensorRT uses per_tensor_symmetric with uint8 for activations, while PyTorch uses per_tensor_affine with quint8 for activations. When customization of dtype and qscheme is ready, we can support TensorRT.
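To make the mismatch concrete: a symmetric scheme pins zero_point so that real zero maps exactly onto the integer zero of the range, while the default affine activation observer used in this PR generally produces a non-zero zero_point that a symmetric-only backend cannot represent. A small sketch of the difference, assuming torch.quantization observers:

```python
import torch
from torch.quantization import MinMaxObserver

x = torch.randn(1000)  # activations spanning negative and positive values

affine = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
symmetric = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
affine(x)
symmetric(x)

print(affine.calculate_qparams())     # non-zero zero_point in general
print(symmetric.calculate_qparams())  # zero_point fixed at 0, as a symmetric backend expects
```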