change signature of quantize_input
chenbohua3 committed Aug 6, 2021
1 parent a15eeab commit f27dc54
Showing 2 changed files with 13 additions and 23 deletions.
26 changes: 7 additions & 19 deletions nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -187,12 +187,7 @@ def validate_config(self, model, config_list):
     def record(self, wrapper, quant_type, tensor):
         name = wrapper.name
         observer = self.all_observers[name][quant_type]
-        if isinstance(tensor, tuple):
-            # NB: This only works for single tensor
-            tensor = (t.cpu() for t in tensor)
-            observer(*tensor)
-        else:
-            observer(tensor.cpu())
+        observer(tensor.cpu())

     def calculate_qparams(self, name, quant_type):
         observer = self.all_observers[name][quant_type]
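With the signature change, `record()` can assume a single tensor per call and drop the tuple branch. A minimal sketch of the flow it feeds, using `torch.quantization.MinMaxObserver` as a stand-in for whatever observer `all_observers` actually holds (the observer type here is an assumption):

```python
import torch
from torch.quantization import MinMaxObserver

# A MinMaxObserver records the running min/max of every tensor it is
# called with and derives (scale, zero_point) from them.
observer = MinMaxObserver(dtype=torch.quint8)
x = torch.randn(4, 8)
observer(x.cpu())                 # mirrors record(): one tensor, moved to CPU
scale, zero_point = observer.calculate_qparams()
print(scale.item(), zero_point.item())
```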
@@ -206,17 +201,14 @@ def _quantize(self, x, scale, zero_point, qmin, qmax):
         x = (x - zero_point) * scale
         return x

-    def quantize_input(self, *inputs, wrapper, **kwargs):
+    def quantize_input(self, inputs, wrapper, **kwargs):
         if self.compressed:
             module = wrapper.module
-            new_input = self._quantize(inputs[0],
+            inputs = self._quantize(inputs,
                                        module.input_scale,
                                        module.input_zero_point,
                                        module.input_qmin,
                                        module.input_qmax)
-            list_inp = list(inputs)
-            list_inp[0] = new_input
-            inputs = tuple(list_inp)
         else:
             self.record(wrapper, 'input', inputs)
         return inputs
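For context, the visible tail of `_quantize` is the dequantize step of a simulated quantize-dequantize round trip. A self-contained sketch of that round trip, with the clamp/round half reconstructed as an assumption from the dequantize line shown in the diff:

```python
import torch

def fake_quantize(x, scale, zero_point, qmin, qmax):
    # quantize: scale, shift, clamp, round (reconstructed here; an assumption)
    x = torch.clamp(x / scale + zero_point, qmin, qmax).round()
    # dequantize: exactly the tail visible in the diff above
    x = (x - zero_point) * scale
    return x

print(fake_quantize(torch.randn(2, 3), scale=0.05, zero_point=0.0, qmin=-128, qmax=127))
```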
@@ -969,20 +961,16 @@ def quantize_output(self, output, wrapper, **kwargs):
         output = self.quantize(output, module.output_scale, module.output_qmin, module.output_qmax)
         return output

-    def quantize_input(self, *inputs, wrapper, **kwargs):
-        # This is hacky since it is not recommended to modify a tuple
-        # NB: support layers with multi inputs
+    def quantize_input(self, inputs, wrapper, **kwargs):
         module = wrapper.module
         # initialize the scale
         if self.bound_model.steps == 1:
             qmax = module.input_qmax
-            init_oup_scale = inputs[0].data.detach().abs().mean() * 2 / (qmax ** 0.5)
+            init_oup_scale = inputs.data.detach().abs().mean() * 2 / (qmax ** 0.5)
             module.input_scale.data = init_oup_scale

-        new_input = self.quantize(inputs[0], module.input_scale, module.input_qmin, module.input_qmax)
-        list_inp = list(inputs)
-        list_inp[0] = new_input
-        return tuple(list_inp)
+        inputs = self.quantize(inputs, module.input_scale, module.input_qmin, module.input_qmax)
+        return inputs

     def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
         """
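The first-step initialization above seeds the learnable input scale from the mean absolute value of the (now single) input tensor, the usual LSQ-style heuristic. A hedged, standalone sketch (the function name is illustrative, not NNI's API):

```python
import torch

def init_lsq_scale(inputs: torch.Tensor, qmax: int) -> torch.Tensor:
    # 2 * E[|x|] / sqrt(qmax), computed on the single input tensor
    return inputs.detach().abs().mean() * 2 / (qmax ** 0.5)

print(init_lsq_scale(torch.randn(16, 32), qmax=127))
```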
10 changes: 6 additions & 4 deletions nni/compression/pytorch/compressor.py
@@ -544,10 +544,12 @@ def __init__(self, module, module_name, module_type, config, quantizer, bn_module):

     def forward(self, *inputs):
         if 'input' in self.config['quant_types']:
-            inputs = self.quantizer.quant_grad(
-                inputs,
+            assert len(inputs) == 1, "Quantization of input only supports ops with single input"
+            new_inp = self.quantizer.quant_grad(
+                inputs[0],
                 QuantType.QUANT_INPUT,
                 self)
+            inputs = (new_inp,)

         if 'weight' in self.config['quant_types'] and _check_weight(self.module):
             if self.bn_module is not None:
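The wrapper is now the single place that still deals in tuples: it unpacks the lone positional input, hands a bare tensor to `quant_grad`, and re-packs the result because the wrapped op is ultimately invoked with the unpacked tuple. A hedged illustration of that convention (the identity function stands in for the real `quant_grad`):

```python
import torch

def quant_grad(tensor):
    return tensor  # identity stand-in for the real fake-quantize op

inputs = (torch.randn(2, 3),)   # what forward(*inputs) receives
assert len(inputs) == 1, "Quantization of input only supports ops with single input"
new_inp = quant_grad(inputs[0]) # the quantizer now sees a bare tensor
inputs = (new_inp,)             # re-pack before calling self.module(*inputs)
print(torch.nn.ReLU()(*inputs).shape)
```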
@@ -640,7 +642,7 @@ def quantize_output(self, output, wrapper, **kwargs):
         """
         raise NotImplementedError('Quantizer must overload quantize_output()')

-    def quantize_input(self, *inputs, wrapper, **kwargs):
+    def quantize_input(self, inputs, wrapper, **kwargs):
         """
         quantize should overload this method to quantize input.
         This method is effectively hooked to :meth:`forward` of the model.
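Under the new base-class signature, a custom quantizer's override receives a single tensor and returns one. A hedged example subclass (`MyQuantizer` and its max-abs scaling are illustrative, not part of NNI):

```python
import torch

class MyQuantizer:  # would derive from NNI's Quantizer in real use
    def quantize_input(self, inputs, wrapper, **kwargs):
        # 'inputs' is a single torch.Tensor under the new signature
        scale = inputs.detach().abs().max() / 127
        q = torch.clamp((inputs / scale).round(), -128, 127)
        return q * scale  # return a tensor, not a tuple

print(MyQuantizer().quantize_input(torch.randn(4, 4), wrapper=None))
```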
@@ -912,7 +914,7 @@ def _check_bias(module):

 def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
     if quant_type == QuantType.QUANT_INPUT:
-        output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs)
+        output = wrapper.quantizer.quantize_input(tensor, wrapper=wrapper, **kwargs)
     elif quant_type == QuantType.QUANT_WEIGHT:
         output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
     elif quant_type == QuantType.QUANT_OUTPUT:
