Add LSQ quantizer #3503
Changes from 1 commit
@@ -593,28 +593,43 @@ def __init__(self, model, config_list, optimizer=None):
modules_to_compress = self.get_modules_to_compress()
self.bound_model.register_buffer("steps", torch.Tensor([1]))
for layer, config in modules_to_compress:
layer.module.register_parameter("scale", torch.nn.Parameter(torch.Tensor([1.0])))
if "weight" in config.get("quant_types", []):
# todo: support per-channel quantization for weight since TensorRT it for conv weight
layer.module.register_parameter("weight_scale", torch.nn.Parameter(torch.Tensor([1.0])))
# todo: support per-channel quantization for weight since TensorRT use it for conv weight
q_bit = get_bits_length(config, "weight")
Comment: In the current implementation, do we only support a single bit width in LsqQuantizer, or can we support mixed precision right now?
Reply: It seems that mixed precision quantization is supported in this implementation, since each layer has its own bit setting read from its config.
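Since `q_bit` is read per layer from its config, mixed precision can be expressed directly in the `config_list` passed to the quantizer. A minimal sketch, assuming NNI's usual quantization config format (layer names and bit widths here are purely illustrative):

```python
# Hypothetical mixed-precision config: 8-bit Conv2d weights/outputs,
# 4-bit weights for a layer named 'fc'.
config_list = [
    {
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_types': ['Conv2d'],
    },
    {
        'quant_types': ['weight'],
        'quant_bits': {'weight': 4},
        'op_names': ['fc'],
    },
]
```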
layer.module.register_buffer('weight_bit', torch.Tensor([q_bit]))
qmax = 2 ** (q_bit - 1) - 1
qmin = -2 ** (q_bit - 1)
init_weight_scale = layer.module.weight.data.detach().abs().mean() * 2 / (qmax ** 0.5)
layer.module.scale = torch.nn.Parameter(init_weight_scale)
layer.module.weight_scale = torch.nn.Parameter(init_weight_scale)
layer.module.weight_qmax = qmax
layer.module.weight_qmin = qmin

self.optimizer.add_param_group({"params": layer.module.weight_scale})

if "output" in config.get("quant_types", []):
# scale of activation will be initialized using the first batch data
layer.module.register_parameter("output_scale", torch.nn.Parameter(torch.Tensor([1.0])))
q_bit = get_bits_length(config, "output")
layer.module.register_buffer('activation_bit', torch.Tensor([q_bit]))
layer.module.register_buffer('output_bit', torch.Tensor([q_bit]))
qmax = 2 ** (q_bit - 1) - 1
qmin = -2 ** (q_bit - 1)
layer.module.output_qmax = qmax
layer.module.output_qmin = qmin

self.optimizer.add_param_group({"params": layer.module.output_scale})

if "input" in config.get("quant_types", []):
# scale of activation will be initialized using the first batch data
Comment: activation -> input
Reply: done
layer.module.register_parameter("input_scale", torch.nn.Parameter(torch.Tensor([1.0])))
q_bit = get_bits_length(config, "input")
layer.module.register_buffer('input_bit', torch.Tensor([q_bit]))
qmax = 2 ** (q_bit - 1) - 1
qmin = -2 ** (q_bit - 1)
layer.module.activation_qmax = qmax
layer.module.activation_qmin = qmin
# add scale to optimizer since they are updated through the gradient
self.optimizer.add_param_group({"params": layer.module.scale})
layer.module.input_qmax = qmax
layer.module.input_qmin = qmin

self.optimizer.add_param_group({"params": layer.module.input_scale})

@staticmethod
def grad_scale(x, scale):
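The body of `grad_scale` is cut off by the hunk above. For orientation, a minimal sketch of the elementwise LSQ forward pass it is used in, reconstructed from the LSQ paper rather than copied from this PR, could look like this (the damped gradient into the learned scale and the straight-through round are the two key pieces):

```python
import torch

def grad_scale(x, scale):
    # Forward pass returns x unchanged; backward pass multiplies the gradient by `scale`.
    y_grad = x * scale
    return (x - y_grad).detach() + y_grad

def round_pass(x):
    # Straight-through estimator: round in the forward pass, identity gradient in the backward pass.
    return (x.round() - x).detach() + x

def quantize(x, scale, qmin, qmax):
    # Damp the gradient that flows into the learned scale, as suggested in the LSQ paper.
    grad_scale_factor = 1.0 / ((qmax * x.numel()) ** 0.5)
    scale = grad_scale(scale, grad_scale_factor)
    x = torch.clamp(x / scale, qmin, qmax)
    x = round_pass(x)
    return x * scale
```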
@@ -649,7 +664,7 @@ def quantize_weight(self, wrapper, **kwargs):
# todo: add support for quantize bias. If we use TensorRT as backend, there is no need to quantize
# bias
old_weight = module.old_weight
weight = self.quantize(old_weight, module.scale, module.weight_qmin, module.weight_qmax)
weight = self.quantize(old_weight, module.weight_scale, module.weight_qmin, module.weight_qmax)
module.weight = weight
return weight
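As a quick sanity check of the ranges set up in `__init__`: for 8-bit weights, `qmax = 2 ** 7 - 1 = 127` and `qmin = -128`, so quantized values lie on a symmetric integer grid scaled by `weight_scale`. A tiny numeric illustration with made-up values:

```python
import torch

q_bit = 8
qmax = 2 ** (q_bit - 1) - 1   # 127
qmin = -2 ** (q_bit - 1)      # -128

w = torch.tensor([0.30, -0.70, 1.50])
scale = torch.tensor(0.01)

q = torch.clamp(w / scale, qmin, qmax).round() * scale
print(q)  # tensor([ 0.3000, -0.7000,  1.2700]) -- 1.50 saturates at 127 * 0.01
```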
@@ -658,13 +673,28 @@ def quantize_output(self, output, wrapper, **kwargs):

# initialize the scale
if self.bound_model.steps == 1:
qmax = module.activation_qmax
qmax = module.output_qmax
init_oup_scale = output.data.detach().abs().mean() * 2 / (qmax ** 0.5)
module.scale.data = init_oup_scale
module.output_scale.data = init_oup_scale

output = self.quantize(output, module.scale, module.activation_qmin, module.activation_qmax)
output = self.quantize(output, module.output_scale, module.output_qmin, module.output_qmax)
return output
Comment: Can this quantization algorithm support exporting the model and related quantization parameters? If yes, maybe we can consider adding an `export_model` function.
Reply: I will check it out.
|
||
def quantize_input(self, *inputs, wrapper, **kwargs): | ||
# This is hacky since it is not recommended to modify a tuple | ||
# NB: support layers with multi inputs | ||
module = wrapper.module | ||
# initialize the scale | ||
if self.bound_model.steps == 1: | ||
qmax = module.input_qmax | ||
init_oup_scale = inputs[0].data.detach().abs().mean() * 2 / (qmax ** 0.5) | ||
module.input_scale.data = init_oup_scale | ||
|
||
new_input = self.quantize(inputs[0], module.input_scale, module.input_qmin, module.input_qmax) | ||
list_inp = list(inputs) | ||
list_inp[0] = new_input | ||
Comment: Why do we only quantize the first input?
Reply: It seems that currently the quantization framework only supports layers with a single input (see here; so does the TensorRT backend, see here), so the current implementation does not support layers with multiple inputs. It may be a better choice to modify the lsq quantizer to support layers with multiple inputs after the framework supports it.
Reply: Got it, that is reasonable.
return tuple(list_inp)
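For reference, wiring LsqQuantizer into a training loop follows the usual NNI quantizer workflow. A rough sketch, with an assumed import path and an illustrative model:

```python
import torch
# import path is an assumption; adjust to wherever LsqQuantizer lands in this PR
from nni.algorithms.compression.pytorch.quantization import LsqQuantizer

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
config_list = [{
    'quant_types': ['weight', 'output'],
    'quant_bits': {'weight': 8, 'output': 8},
    'op_types': ['Conv2d'],
}]

quantizer = LsqQuantizer(model, config_list, optimizer)
quantizer.compress()
# Train as usual: the registered *_scale parameters sit in extra optimizer
# param groups, so they are learned along with the weights.
```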
def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and calibration parameters (optional)
@@ -692,18 +722,18 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
calibration_config = {}

for name, module in self.bound_model.named_modules():
if hasattr(module, 'weight_bit') or hasattr(module, 'activation_bit'):
if hasattr(module, 'input_bit') or hasattr(module, 'output_bit'):
calibration_config[name] = {}
if hasattr(module, 'weight_bit'):
calibration_config[name]['weight_bit'] = int(module.weight_bit)
abs_max_weight = float(module.scale * module.weight_qmax)
calibration_config[name]['tracked_min_input'] = -abs_max_weight
calibration_config[name]['tracked_max_input'] = abs_max_weight
if hasattr(module, 'activation_bit'):
calibration_config[name]['activation_bit'] = int(module.activation_bit)
abs_max_activation = float(module.scale * module.activation_qmax)
calibration_config[name]['tracked_min_activation'] = -abs_max_activation
calibration_config[name]['tracked_max_activation'] = abs_max_activation
if hasattr(module, 'input_bit'):
calibration_config[name]['weight_bit'] = int(module.input_bit)
Comment: Why is `weight_bit` assigned from `module.input_bit` here?
Reply: According to here,
Reply: Currently, we choose to record the range of the input tensor during the process of quantizing weights in the QAT algorithm. The reason we handle it this way is the requirement of integration with TensorRT, which needs the input tensor's dynamic range when setting layer precision to 8 bit. So we record the input dynamic range as here.
Reply: Got it. How about changing the code like:
abs_max_input = float(module.input_scale * module.input_qmax)
calibration_config[name]['tracked_min_input'] = -abs_max_input
calibration_config[name]['tracked_max_input'] = abs_max_input
if hasattr(module, 'output_bit'):
calibration_config[name]['activation_bit'] = int(module.output_bit)
abs_max_output = float(module.output_scale * module.output_qmax)
calibration_config[name]['tracked_min_activation'] = -abs_max_output
calibration_config[name]['tracked_max_activation'] = abs_max_output
self._del_simulated_attr(module)

self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
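To illustrate what this loop produces, a calibration_config entry for one quantized layer might look roughly like the following (layer name and numbers are made up; the keys follow the code above):

```python
calibration_config = {
    'features.0': {
        'weight_bit': 8,                  # overwritten from input_bit when 'input' is quantized
        'tracked_min_input': -0.84,       # -input_scale * input_qmax
        'tracked_max_input': 0.84,        #  input_scale * input_qmax
        'activation_bit': 8,              # taken from output_bit
        'tracked_min_activation': -1.27,  # -output_scale * output_qmax
        'tracked_max_activation': 1.27,   #  output_scale * output_qmax
    },
}
```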
@@ -715,8 +745,8 @@ def _del_simulated_attr(self, module):
"""
delete redundant parameters in quantize module
"""
del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', \
'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
del_attr_list = ['old_weight', 'tracked_min_input', 'tracked_max_input', 'tracked_min_activation', \
'tracked_max_activation', 'output_scale', 'input_scale', 'weight_scale', 'weight_bit', 'output_bit', 'input_bit']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
Comment: this comment can be removed
Reply: done