use first batch data to initialize activation scale
chenbohua3 committed May 14, 2021
1 parent f0922c4 commit 0e9b8f0
Showing 1 changed file with 17 additions and 3 deletions.
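
In the LSQ paper this quantizer follows (Esser et al., "Learned Step Size Quantization"), the step size for activations is initialized from the first input batch as scale = 2 * mean(|x|) / sqrt(qmax), rather than being left at a fixed 1.0. This commit adds that initialization, gated by a new `steps` buffer, and also fixes `export_model` to derive the activation range from `activation_qmax` instead of `weight_qmax`.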
20 changes: 17 additions & 3 deletions nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -591,6 +591,7 @@ def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
         self.quant_grad = QuantForward()
         modules_to_compress = self.get_modules_to_compress()
+        self.bound_model.register_buffer("steps", torch.Tensor([1]))
         for layer, config in modules_to_compress:
             layer.module.register_parameter("scale", torch.nn.Parameter(torch.Tensor([1.0])))
             if "weight" in config.get("quant_types", []):
@@ -604,9 +605,9 @@ def __init__(self, model, config_list, optimizer=None):
                 layer.module.weight_qmax = qmax
                 layer.module.weight_qmin = qmin
 
-            # todo: in the origin paper, the initial value of activation is calculated from first input batch
             if "output" in config.get("quant_types", []):
-                q_bit = get_bits_length(config, "")
+                # scale of activation will be initialized using the first batch data
+                q_bit = get_bits_length(config, "output")
                 layer.module.register_buffer('activation_bit', torch.Tensor([q_bit]))
                 qmax = 2 ** (q_bit - 1) - 1
                 qmin = -2 ** (q_bit - 1)
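
For example, with q_bit = 8 this gives the signed integer range qmin = -128, qmax = 127.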
@@ -654,6 +655,13 @@ def quantize_weight(self, wrapper, **kwargs):
 
     def quantize_output(self, output, wrapper, **kwargs):
         module = wrapper.module
+
+        # initialize the scale
+        if self.bound_model.steps == 1:
+            qmax = module.activation_qmax
+            init_oup_scale = output.data.detach().abs().mean() * 2 / (qmax ** 0.5)
+            module.scale.data = init_oup_scale
+
         output = self.quantize(output, module.scale, module.activation_qmin, module.activation_qmax)
         return output
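
A minimal standalone sketch of the initialization rule added above (hypothetical tensors; only torch is assumed):

    import torch

    bits = 8
    qmax = 2 ** (bits - 1) - 1     # 127, matching the value computed in __init__
    qmin = -2 ** (bits - 1)        # -128
    output = torch.randn(32, 64)   # stand-in for the first batch of activations

    # LSQ-style initialization: scale = 2 * E[|x|] / sqrt(qmax)
    scale = output.detach().abs().mean() * 2 / (qmax ** 0.5)

    # simulated quantize/dequantize with the initialized scale
    q = torch.clamp(torch.round(output / scale), qmin, qmax)
    dequantized = q * scale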

@@ -693,7 +701,7 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
                 calibration_config[name]['tracked_max_input'] = abs_max_weight
             if hasattr(module, 'activation_bit'):
                 calibration_config[name]['activation_bit'] = int(module.activation_bit)
-                abs_max_activation = float(module.scale * module.weight_qmax)
+                abs_max_activation = float(module.scale * module.activation_qmax)
                 calibration_config[name]['tracked_min_activation'] = -abs_max_activation
                 calibration_config[name]['tracked_max_activation'] = abs_max_activation
             self._del_simulated_attr(module)
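
The one-line change above fixes the exported calibration range: it was previously computed from weight_qmax, which is only correct when weights and activations use the same bit width; using activation_qmax makes tracked_min_activation/tracked_max_activation consistent with activation_bit.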
@@ -712,3 +720,9 @@ def _del_simulated_attr(self, module):
         for attr in del_attr_list:
             if hasattr(module, attr):
                 delattr(module, attr)
+
+    def step_with_optimizer(self):
+        """
+        Override the `compressor` `step` method; quantization only happens after a certain number of steps.
+        """
+        self.bound_model.steps += 1
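
How the two new pieces interact, as a minimal mock (a hypothetical stand-in class, not the NNI API): quantize_output sets the scale only while steps == 1, and step_with_optimizer advances the counter after each optimizer step, so the initialization runs exactly once, on the first batch.

    import torch

    class MockLsq:
        # hypothetical stand-in illustrating the first-batch gating; not the NNI API
        def __init__(self, bits=8):
            self.steps = torch.tensor([1.0])  # mirrors the registered "steps" buffer
            self.scale = torch.nn.Parameter(torch.tensor([1.0]))
            self.qmax = 2 ** (bits - 1) - 1
            self.qmin = -2 ** (bits - 1)

        def quantize_output(self, output):
            if self.steps == 1:  # true only for the first batch
                self.scale.data = output.detach().abs().mean() * 2 / (self.qmax ** 0.5)
            q = torch.clamp(torch.round(output / self.scale), self.qmin, self.qmax)
            return q * self.scale

        def step_with_optimizer(self):
            self.steps += 1  # advance once per optimizer step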
