
Commit: export
chenbohua3 committed Apr 18, 2021
1 parent 4ffeda9 commit 9c09bf0
Showing 1 changed file with 64 additions and 8 deletions.
72 changes: 64 additions & 8 deletions nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -592,11 +592,11 @@ def __init__(self, model, config_list, optimizer=None):
self.quant_grad = QuantForward()
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
layer.module.register_parameter("zero_point", torch.nn.Parameter(torch.Tensor([0.0])))
layer.module.register_parameter("scale", torch.nn.Parameter(torch.Tensor([1.0])))
if "weight" in config.get("quant_types", []):
# todo: support per-channel quantization for weight since TensorRT uses it for conv weights
q_bit = get_bits_length(config, "weight")
layer.module.register_buffer('weight_bit', torch.Tensor([q_bit]))
qmax = 2 ** (q_bit - 1) - 1
qmin = -2 ** (q_bit - 1)
init_weight_scale = layer.module.weight.data.detach().abs().mean() * 2 / (qmax ** 0.5)
@@ -607,12 +607,12 @@ def __init__(self, model, config_list, optimizer=None):
# todo: in the original paper, the initial value of the activation scale is calculated from the first input batch
if "output" in config.get("quant_types", []):
q_bit = get_bits_length(config, "output")
layer.module.register_buffer('activation_bit', torch.Tensor([q_bit]))
qmax = 2 ** (q_bit - 1) - 1
qmin = -2 ** (q_bit - 1)
layer.module.activation_qmax = qmax
layer.module.activation_qmin = qmin
# add zero_point and scale to optimizer since they are updated through the gradient
self.optimizer.add_param_group({"params": layer.module.zero_point})
# add scale to optimizer since it is updated through the gradient
self.optimizer.add_param_group({"params": layer.module.scale})
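
For reference, a minimal standalone sketch (not part of this commit; the Linear layer and optimizer below are illustrative) of what this loop does for a single layer: register a learnable scale initialized with the LSQ rule 2 * E[|w|] / sqrt(qmax), then hand it to the optimizer so it is trained together with the weights.

import torch

layer = torch.nn.Linear(64, 32)
optimizer = torch.optim.SGD(layer.parameters(), lr=1e-2)

q_bit = 8
qmax = 2 ** (q_bit - 1) - 1   # 127 for signed 8-bit
qmin = -2 ** (q_bit - 1)      # -128, used later when clamping in quantize()

# LSQ initialization of the step size: 2 * E[|w|] / sqrt(qmax)
init_weight_scale = layer.weight.data.detach().abs().mean() * 2 / (qmax ** 0.5)
layer.register_parameter("scale", torch.nn.Parameter(torch.Tensor([float(init_weight_scale)])))

# the step size is learned through the gradient, so it joins the optimizer
optimizer.add_param_group({"params": layer.scale})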

@staticmethod
@@ -633,13 +633,13 @@ def round_pass(x):
y_grad = x
return (y - y_grad).detach() + y_grad
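
round_pass is a straight-through estimator: the forward pass uses the rounded value while the backward pass treats rounding as the identity. A quick standalone check (illustrative, not part of the commit):

import torch

x = torch.tensor([0.3, 1.7], requires_grad=True)
y = x.round()
y_grad = x
out = (y - y_grad).detach() + y_grad   # same trick as round_pass
out.sum().backward()
print(out.detach())  # tensor([0., 2.]) -> rounded values in the forward pass
print(x.grad)        # tensor([1., 1.]) -> identity gradient in the backward pass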

def quantize(self, x, scale, zero_point, qmin, qmax):
def quantize(self, x, scale, qmin, qmax):
grad_scale_factor = 1.0 / ((qmax * x.numel()) ** 0.5)
scale = self.grad_scale(scale, grad_scale_factor)
x = x / scale + zero_point
x = x / scale
x = torch.clamp(x, qmin, qmax)
x = self.round_pass(x)
x = (x - zero_point) * scale
x = x * scale
return x
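
Taken together, quantize() performs LSQ fake quantization: scale the step-size gradient by 1 / sqrt(qmax * numel), divide by the step size, clamp to [qmin, qmax], round with the straight-through estimator, and multiply back. A self-contained sketch using plain tensors instead of the quantizer/wrapper classes (function names here are illustrative):

import torch

def grad_scale(x, factor):
    # forward: x unchanged; backward: gradient multiplied by `factor`
    return (x - x * factor).detach() + x * factor

def round_pass(x):
    # straight-through estimator: round in the forward pass, identity in the backward pass
    return (x.round() - x).detach() + x

def lsq_fake_quant(x, scale, qmin, qmax):
    grad_scale_factor = 1.0 / ((qmax * x.numel()) ** 0.5)
    s = grad_scale(scale, grad_scale_factor)
    q = round_pass(torch.clamp(x / s, qmin, qmax))
    return q * s

x = torch.randn(4, 16)
scale = torch.nn.Parameter(torch.tensor([0.05]))
y = lsq_fake_quant(x, scale, qmin=-128, qmax=127)
y.abs().sum().backward()
print(scale.grad)  # the step size receives a gradient and is learned during training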

def quantize_weight(self, wrapper, **kwargs):
@@ -648,11 +648,67 @@ def quantize_weight(self, wrapper, **kwargs):
# todo: add support for quantizing bias. If we use TensorRT as the backend, there is no need to quantize
# bias
old_weight = module.old_weight
weight = self.quantize(old_weight, module.scale, module.zero_point, module.weight_qmin, module.weight_qmax)
weight = self.quantize(old_weight, module.scale, module.weight_qmin, module.weight_qmax)
module.weight = weight
return weight

def quantize_output(self, output, wrapper, **kwargs):
module = wrapper.module
output = self.quantize(output, module.scale, module.zero_point, module.activation_qmin, module.activation_qmax)
output = self.quantize(output, module.scale, module.activation_qmin, module.activation_qmax)
return output

def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and calibration parameters (optional).
Parameters
----------
model_path : str
path to save quantized model weight
calibration_path : str
(optional) path to save quantization parameters after calibration
onnx_path : str
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
device : torch.device
device of the model, used to place the dummy input tensor when exporting the onnx file.
The tensor is placed on CPU if `device` is None.
Returns
-------
Dict
calibration parameters of each quantized layer, keyed by layer name
"""
assert model_path is not None, 'model_path must be specified'
self._unwrap_model()
calibration_config = {}

for name, module in self.bound_model.named_modules():
if hasattr(module, 'weight_bit') or hasattr(module, 'activation_bit'):
calibration_config[name] = {}
if hasattr(module, 'weight_bit'):
calibration_config[name]['weight_bit'] = int(module.weight_bit)
abs_max_weight = float(module.scale * module.weight_qmax)
calibration_config[name]['tracked_min_input'] = -abs_max_weight
calibration_config[name]['tracked_max_input'] = abs_max_weight
if hasattr(module, 'activation_bit'):
calibration_config[name]['activation_bit'] = int(module.activation_bit)
abs_max_activation = float(module.scale * module.activation_qmax)
calibration_config[name]['tracked_min_activation'] = -abs_max_activation
calibration_config[name]['tracked_max_activation'] = abs_max_activation
self._del_simulated_attr(module)

self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
input_shape, device)

return calibration_config
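
The calibration entries written above follow from the symmetric quantization grid: with a learned step size scale and integer range [qmin, qmax], the representable range is [-scale * qmax, scale * qmax], which is what backends such as TensorRT consume as tracked min/max values. A small worked example (values made up for illustration):

scale = 0.02                        # learned step size for one layer
weight_bit = 8
qmax = 2 ** (weight_bit - 1) - 1    # 127

abs_max = scale * qmax              # 2.54
entry = {
    'weight_bit': weight_bit,
    'tracked_min_input': -abs_max,
    'tracked_max_input': abs_max,
}
print(entry)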

def _del_simulated_attr(self, module):
"""
delete redundant parameters in quantize module
"""
del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', \
'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
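
A hypothetical end-to-end usage of the new export_model(), following the usual NNI v2 quantizer workflow; the model, config keys, import path, and file names below are placeholders and may differ across NNI releases:

import torch
from nni.algorithms.compression.pytorch.quantization import LsqQuantizer

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
config_list = [{
    'quant_types': ['weight', 'output'],
    'quant_bits': {'weight': 8, 'output': 8},
    'op_types': ['Conv2d'],
}]
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

quantizer = LsqQuantizer(model, config_list, optimizer)
quantizer.compress()

# ... quantization-aware training loop updates the weights and the learned scales ...

calibration_config = quantizer.export_model('lsq_model.pth', 'lsq_calibration.pth')
# calibration_config maps layer names to entries such as
# {'weight_bit': 8, 'tracked_min_input': -1.2, 'tracked_max_input': 1.2, ...}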
