Add LSQ quantizer #3503
Changes from 11 commits
New file: MNIST example using the LSQ quantizer (@@ -0,0 +1,97 @@):
```python
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import LsqQuantizer


class Mnist(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)
        self.relu1 = torch.nn.ReLU6()
        self.relu2 = torch.nn.ReLU6()
        self.relu3 = torch.nn.ReLU6()

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = self.relu2(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(model, quantizer, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}%  Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print('Loss: {}  Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))


def main():
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000, shuffle=True)

    model = Mnist()
    # you can change this to DoReFaQuantizer to use that algorithm instead:
    # DoReFaQuantizer(configure_list).compress(model)
    configure_list = [{
        'quant_types': ['weight'],
        'quant_bits': {
            'weight': 8,
        },  # you can just use an `int` here because all `quant_types` share the same bit length; see the config for `ReLU6` below.
        'op_types': ['Conv2d', 'Linear']
    }, {
        'quant_types': ['output'],
        'quant_bits': 8,
        'quant_start_step': 1000,
        'op_types': ['ReLU6']
    }]
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    quantizer = LsqQuantizer(model, configure_list, optimizer)
    quantizer.compress()

    model.to(device)
    for epoch in range(40):
        print('# Epoch {} #'.format(epoch))
        train(model, quantizer, device, train_loader, optimizer)
        test(model, device, test_loader)


if __name__ == '__main__':
    main()
```
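After training, the calibration parameters could be exported with the `export_model` method this PR adds. A minimal sketch (the file paths are illustrative, not from the example):

```python
# a sketch: export quantized weights plus calibration parameters after training
# ('mnist_model.pth' and 'mnist_calibration.pth' are hypothetical paths)
calibration_config = quantizer.export_model('mnist_model.pth', 'mnist_calibration.pth')
print('calibration config:', calibration_config)
```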
Changes to the quantizer implementation:
```diff
@@ -6,9 +6,9 @@
 import torch
 from schema import Schema, And, Or, Optional
 from nni.compression.pytorch.utils.config_validation import CompressorSchema
-from nni.compression.pytorch.compressor import Quantizer, QuantGrad, QuantType
+from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType

-__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']
+__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']

 logger = logging.getLogger(__name__)
```
```diff
@@ -59,7 +59,7 @@ def update_ema(biased_ema, value, decay):
         float, float
     """
     biased_ema = biased_ema * decay + (1 - decay) * value
-    return biased_ema
+    return biased_ema


 def update_quantization_param(bits, rmin, rmax):
```
```diff
@@ -146,7 +146,7 @@ def __init__(self, model, config_list, optimizer=None):
             types of nn.module you want to apply quantization, eg. 'Conv2d'
         """
         super().__init__(model, config_list, optimizer)
-        self.quant_grad = QATGrad
+        self.quant_grad = QATGrad.apply
         modules_to_compress = self.get_modules_to_compress()
         self.bound_model.register_buffer("steps", torch.Tensor([1]))
         for layer, config in modules_to_compress:
```

Review comment: Why do we have to move …
Reply: So it is for avoiding STE in the LSQ quantizer.
Reply: Yes, it is aimed at unifying the framework of quantizers with customized gradients and quantizers with auto-grad gradients. Also, use …
```diff
@@ -474,7 +474,7 @@ class BNNQuantizer(Quantizer):

     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
-        self.quant_grad = ClipGrad
+        self.quant_grad = ClipGrad.apply
         modules_to_compress = self.get_modules_to_compress()
         for layer, config in modules_to_compress:
             if "weight" in config.get("quant_types", []):
```
```diff
@@ -559,4 +559,170 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_

         self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)

-        return calibration_config
+        return calibration_config
```

The remainder of the hunk adds the new `LsqQuantizer` class:
```python
class LsqQuantizer(Quantizer):
    """Quantizer defined in:
    Learned Step Size Quantization (ICLR 2020)
    https://arxiv.org/pdf/1902.08153.pdf
    """
```

Review comment: Please add a docstring as in the other Quantizers, especially for parameters and return.
Reply: done

Review comment: please align
Reply: done
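For reference, the quantization that `LsqQuantizer` simulates follows the paper: for signed $b$-bit quantization with $Q_N = 2^{b-1}$ and $Q_P = 2^{b-1} - 1$,

$$\bar{v} = \lfloor \mathrm{clip}(v/s,\ -Q_N,\ Q_P) \rceil, \qquad \hat{v} = \bar{v} \cdot s,$$

where the step size $s$ is a learnable parameter initialized from the data as $s_0 = 2\,\mathbb{E}[|v|]/\sqrt{Q_P}$ (this matches `init_weight_scale` and `init_oup_scale` in the code below).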
```python
    def __init__(self, model, config_list, optimizer=None):
        """
        Parameters
        ----------
        model : torch.nn.Module
            the model to be quantized
        config_list : list of dict
            list of configurations for quantization
            supported keys for dict:
                - quant_types : list of string
                    type of quantization you want to apply, currently support 'weight', 'input', 'output'
                - quant_bits : int or dict of {str : int}
                    bit length of quantization, key is the quantization type and value is the length, e.g. {'weight': 8};
                    when the type is int, all quantization types share the same bit length
                - quant_start_step : int
                    disable quantization until the model has run for a certain number of steps; this allows the
                    network to enter a more stable state where activation quantization ranges do not exclude a
                    significant fraction of values, default value is 0
                - op_types : list of string
                    types of nn.module you want to apply quantization to, e.g. 'Conv2d'
        """
        super().__init__(model, config_list, optimizer)
        self.quant_grad = QuantForward()
```

Review comment: {'weight', 8} -> {'weight': 8}
Reply: done

Review comment: If we keep the original forward and backward structure, LSQ can run the forward pass as usual and the backward pass by STE. Would anything go wrong in this way? It may have something to do with the update of the scale and zero point.
Reply: Nothing will go wrong if the gradients are handled carefully. However, the original framework has one major limitation: we must customize the gradients for all learnable parameters. If a gradient-based algorithm becomes complex, the customization is troublesome and error-prone. In this situation, using the auto-grad system to determine the gradients is more convenient for users.
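To illustrate what this thread is discussing: a `QuantGrad`-style quantizer hand-writes the backward pass (a straight-through estimator), so a gradient for `scale` would also have to be derived by hand, while `QuantForward` lets autograd differentiate through the simulated quantization. A minimal sketch of the two styles (`SteQuantSketch` is a hypothetical illustration, not NNI's actual `QuantGrad`):

```python
import torch

class SteQuantSketch(torch.autograd.Function):
    """Hypothetical custom-gradient quantizer: quantize in forward,
    pass the incoming gradient straight through in backward."""
    @staticmethod
    def forward(ctx, x, scale):
        return torch.clamp((x / scale).round(), -128, 127) * scale

    @staticmethod
    def backward(ctx, grad_output):
        # STE: the gradient w.r.t. x passes through unchanged; a gradient
        # w.r.t. scale would have to be derived and written by hand (None here)
        return grad_output, None

# QuantForward style: build the same computation from differentiable ops
# (plus the detach tricks shown below) and let autograd handle the scale gradient.
x = torch.randn(4, requires_grad=True)
scale = torch.tensor(0.1, requires_grad=True)
y = torch.clamp(x / scale, -128, 127)
y = (y.round() - y).detach() + y      # round, with its gradient treated as 1
(y * scale).sum().backward()
print(scale.grad)                     # autograd-derived step-size gradient
```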
```python
        modules_to_compress = self.get_modules_to_compress()
        self.bound_model.register_buffer("steps", torch.Tensor([1]))
        for layer, config in modules_to_compress:
            layer.module.register_parameter("scale", torch.nn.Parameter(torch.Tensor([1.0])))
            if "weight" in config.get("quant_types", []):
                # todo: support per-channel quantization for weight, since TensorRT uses it for conv weights
                q_bit = get_bits_length(config, "weight")
                layer.module.register_buffer('weight_bit', torch.Tensor([q_bit]))
                qmax = 2 ** (q_bit - 1) - 1
                qmin = -2 ** (q_bit - 1)
                init_weight_scale = layer.module.weight.data.detach().abs().mean() * 2 / (qmax ** 0.5)
                layer.module.scale = torch.nn.Parameter(init_weight_scale)
                layer.module.weight_qmax = qmax
                layer.module.weight_qmin = qmin
```

Review comment: In the current implementation, do we only support a single bit width in LsqQuantizer? Can we support mixed precision right now?
Reply: It seems that mixed-precision quantization is supported in this implementation, since each layer has its own …
```python
            if "output" in config.get("quant_types", []):
                # the scale of the activation will be initialized using the first batch of data
                q_bit = get_bits_length(config, "output")
                layer.module.register_buffer('activation_bit', torch.Tensor([q_bit]))
                qmax = 2 ** (q_bit - 1) - 1
                qmin = -2 ** (q_bit - 1)
                layer.module.activation_qmax = qmax
                layer.module.activation_qmin = qmin
            # add scale to the optimizer since it is updated through the gradient
            self.optimizer.add_param_group({"params": layer.module.scale})
```
```python
    @staticmethod
    def grad_scale(x, scale):
        """
        Used to scale the gradient: the returned tensor equals `x` in the forward
        pass, while its gradient is multiplied by `scale` in the backward pass.
        """
        y = x
        y_grad = x * scale
        return (y - y_grad).detach() + y_grad
```

Review comment: Recommend explaining this function in detail, since both reviewers were confused while reviewing this part. I think this function is also a key part of the LSQ implementation and can help others understand the insight of this algorithm.
Reply: done
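The `(y - y_grad).detach() + y_grad` pattern evaluates to `y` in the forward pass while routing the backward pass entirely through `y_grad`. A quick standalone check of the trick (not part of the PR):

```python
import torch

x = torch.tensor([3.0], requires_grad=True)
scale = 0.25

# forward value equals x, but the gradient arriving at x is multiplied by `scale`
y = (x - x * scale).detach() + x * scale
y.backward()
print(y.item(), x.grad.item())  # 3.0 0.25
```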
```python
    @staticmethod
    def round_pass(x):
        """
        A simple way to execute the `round` operation with its gradient set to 1
        (a straight-through estimator for rounding).
        """
        y = x.round()
        y_grad = x
        return (y - y_grad).detach() + y_grad
```
```python
    def quantize(self, x, scale, qmin, qmax):
        grad_scale_factor = 1.0 / ((qmax * x.numel()) ** 0.5)
        scale = self.grad_scale(scale, grad_scale_factor)
        x = x / scale
        x = torch.clamp(x, qmin, qmax)
        x = self.round_pass(x)
        x = x * scale
        return x
```

Review comment: A little confused about the names of the values and functions. Can we polish the naming here, or in …
Reply: The names of the functions and variables are the same as those defined in the paper.
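The `grad_scale_factor` here is the step-size gradient scale $g = 1/\sqrt{N \cdot Q_P}$ from the paper, where $N$ is the number of elements being quantized (`x.numel()`) and $Q_P$ is the positive clipping level (`qmax`); it keeps the magnitude of the step-size update in proportion to the magnitude of the weight updates.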
```python
    def quantize_weight(self, wrapper, **kwargs):
        module = wrapper.module

        # todo: add support for quantizing bias. If we use TensorRT as the backend,
        # there is no need to quantize bias
        old_weight = module.old_weight
        weight = self.quantize(old_weight, module.scale, module.weight_qmin, module.weight_qmax)
        module.weight = weight
        return weight
```
```python
    def quantize_output(self, output, wrapper, **kwargs):
        module = wrapper.module

        # initialize the scale using the first batch of data
        if self.bound_model.steps == 1:
            qmax = module.activation_qmax
            init_oup_scale = output.data.detach().abs().mean() * 2 / (qmax ** 0.5)
            module.scale.data = init_oup_scale

        output = self.quantize(output, module.scale, module.activation_qmin, module.activation_qmax)
        return output
```

Review comment: It seems that the weight and the activation use the same `scale` in a single module, which means they share the same rescaling parameter, and the value of `scale` will be updated by the gradients of the weight and the activation simultaneously. What would be the consequence of quantizing both the weight and the activation of the same layer? Would it cause something wrong?
Reply: Yes, you are right :) Now each layer will construct input_scale/weight_scale/output_scale according to the config setting.

Review comment: Can this quantization algorithm support exporting the model and the related quantization parameters? If yes, maybe we can consider adding the function …
Reply: I will check it out.
```python
    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional)

        Parameters
        ----------
        model_path : str
            path to save the quantized model weights
        calibration_path : str
            (optional) path to save the quantization parameters after calibration
        onnx_path : str
            (optional) path to save the onnx model
        input_shape : list or tuple
            input shape to the onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting the onnx file.
            the tensor is placed on cpu if ``device`` is None

        Returns
        -------
        Dict
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            if hasattr(module, 'weight_bit') or hasattr(module, 'activation_bit'):
                calibration_config[name] = {}
                if hasattr(module, 'weight_bit'):
                    calibration_config[name]['weight_bit'] = int(module.weight_bit)
                    abs_max_weight = float(module.scale * module.weight_qmax)
                    calibration_config[name]['tracked_min_input'] = -abs_max_weight
                    calibration_config[name]['tracked_max_input'] = abs_max_weight
                if hasattr(module, 'activation_bit'):
                    calibration_config[name]['activation_bit'] = int(module.activation_bit)
                    abs_max_activation = float(module.scale * module.activation_qmax)
                    calibration_config[name]['tracked_min_activation'] = -abs_max_activation
                    calibration_config[name]['tracked_max_activation'] = abs_max_activation
                self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path,
                               onnx_path, input_shape, device)

        return calibration_config
```
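For the MNIST example above, the returned `calibration_config` would have roughly this shape (layer names follow the example model; the numeric values are illustrative only):

```python
{
    'conv1': {
        'weight_bit': 8,
        'tracked_min_input': -0.53,      # -(scale * weight_qmax), illustrative
        'tracked_max_input': 0.53,
    },
    'relu1': {
        'activation_bit': 8,
        'tracked_min_activation': -5.9,  # -(scale * activation_qmax), illustrative
        'tracked_max_activation': 5.9,
    },
    # ... one entry per quantized layer
}
```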
```python
    def _del_simulated_attr(self, module):
        """
        Delete redundant parameters in the quantized module
        """
        del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation',
                         'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bit',
                         'activation_bit']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)

    def step_with_optimizer(self):
        """
        Override the `compressor` `step` method; quantization only happens after a certain number of steps
        """
        self.bound_model.steps += 1
```
Review comment: this comment can be removed
Reply: done