This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit 656c63d: refine

chenbohua3 committed Jul 26, 2021
1 parent d124415

Showing 3 changed files with 51 additions and 6 deletions.
4 changes: 2 additions & 2 deletions examples/model_compress/quantization/observer_quantizer.py
@@ -41,8 +41,8 @@ def test(model, device, test_loader):
 def calibration(model, device, test_loader):
     model.eval()
     with torch.no_grad():
-        for data, target in test_loader:
-            data, target = data.to(device), target.to(device)
+        for data, _ in test_loader:
+            data = data.to(device)
             model(data)


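For context, a minimal sketch of how the calibration helper above plugs into the ObserverQuantizer flow, mirroring the unit test added later in this commit. The import path is inferred from the file layout in this diff, and the toy model and stand-in loader are assumptions for illustration, not part of the example file:

import torch
from nni.algorithms.compression.pytorch.quantization import ObserverQuantizer  # path inferred

device = torch.device('cpu')
config_list = [{
    'quant_types': ['weight', 'input'],
    'quant_bits': 8,                     # only 8-bit is supported for now
    'op_types': ['Conv2d', 'Linear']
}]

# Toy stand-ins for the real MNIST model and DataLoader used in the example.
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, 3),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 26 * 26, 10),
).eval()
test_loader = [(torch.randn(4, 1, 28, 28), torch.zeros(4))]

quantizer = ObserverQuantizer(model, config_list)
calibration(model, device, test_loader)  # observers record weight/activation ranges
quantizer.compress()                     # fold observed ranges into scale / zero_point
quantizer.export_model('model.pth', 'calibration.pth')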
19 changes: 15 additions & 4 deletions nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -261,8 +261,6 @@ def compress(self):
             scale, zero_point = self.calculate_qparams(layer.name, 'weight')
             module.register_buffer('weight_scale', scale.to(self.device))
             module.register_buffer('weight_zero_point', zero_point.to(self.device))
-            # todo: recover old_weight to weight, because the compressed
-            # model may be further finetuned.
         if "input" in config.get("quant_types", []):
             scale, zero_point = self.calculate_qparams(layer.name, 'input')
             module.register_buffer('input_scale', scale.to(self.device))
@@ -301,11 +299,24 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
         calibration_config = {}
 
         for name, module in self.bound_model.named_modules():
-            if hasattr(module, 'input_scale') or hasattr(module, 'output_scale'):
+            if hasattr(module, 'weight_scale') or hasattr(module, 'input_scale') or hasattr(module, 'output_scale'):
                 calibration_config[name] = {}
+            if hasattr(module, 'weight_scale'):
+                calibration_config[name]['weight_bit'] = 8
+                val = float(module.weight_scale * module.weight_qmax)
+                calibration_config[name]['tracked_min_weight'] = -val
+                calibration_config[name]['tracked_max_weight'] = val
+                calibration_config[name]['tracked_weight_qmin'] = -127
+                calibration_config[name]['tracked_weight_qmax'] = 127
+                actual_weight = getattr(module, 'old_weight', None)
+                if actual_weight is None:
+                    logger.warning("Can not recover weight for layer %s. "
+                                   "This may lead to a wrong accuracy performance on the backend.", name)
+                delattr(module, 'weight')
+                module.register_parameter('weight', actual_weight)
+                # refactor these magic numbers when customizations of dtype and qscheme are ready.
             if hasattr(module, 'input_scale'):
-                calibration_config[name]['weight_bit'] = 8
                 calibration_config[name]['input_bit'] = 8
                 max_input = float(module.input_scale * (module.input_qmax - module.input_zero_point))
                 min_input = float(module.input_scale * (module.input_qmin - module.input_zero_point))
                 calibration_config[name]['tracked_min_input'] = min_input
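The export path above records, for each layer, the real-valued range the integer grid can represent (for example, max_input = input_scale * (input_qmax - input_zero_point)). As a hedged illustration of how a backend might apply these exported values (the function name and defaults below are illustrative, not NNI API), a minimal fake-quantize round trip:

import torch

def fake_quantize(x, scale, zero_point, qmin=-127, qmax=127):
    # Illustrative only: map x onto the integer grid implied by
    # scale / zero_point, clamp to [qmin, qmax], then map back.
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

# The tracked extrema round-trip exactly, e.g.:
# fake_quantize(torch.tensor(max_input), scale, zero_point) == max_input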
34 changes: 34 additions & 0 deletions test/ut/sdk/test_compressor_torch.py
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import copy
 from unittest import TestCase, main
 import numpy as np
 import torch
@@ -263,6 +264,39 @@ def test_torch_taylorFOweight_pruner_global_sort(self):
         assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.]))
         assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))
 
+    def test_torch_observer_quantizer(self):
+        model = TorchModel()
+        # test invalid config
+        # only support 8bit for now
+        config_list = [{
+            'quant_types': ['weight'],
+            'quant_bits': 5,
+            'op_types': ['Conv2d', 'Linear']
+        }]
+        with self.assertRaises(schema.SchemaError):
+            torch_quantizer.ObserverQuantizer(model, config_list)
+
+        # weight will not change for now
+        model = TorchModel().eval()
+        origin_parameters = copy.deepcopy(dict(model.named_parameters()))
+
+        config_list = [{
+            'quant_types': ['weight'],
+            'quant_bits': 8,
+            'op_types': ['Conv2d', 'Linear']
+        }]
+        quantizer = torch_quantizer.ObserverQuantizer(model, config_list)
+        input = torch.randn(1, 1, 28, 28)
+        model(input)
+        quantizer.compress()
+        model_path = "test_model.pth"
+        calibration_path = "test_calibration.pth"
+        calibration_config = quantizer.export_model(model_path, calibration_path)
+        new_parameters = dict(model.named_parameters())
+        self.assertTrue(all(torch.equal(v, new_parameters[k]) for k, v in origin_parameters.items()))
+        self.assertTrue(calibration_config is not None)
+        self.assertTrue(len(calibration_config) == 4)
+
     def test_torch_QAT_quantizer(self):
         model = TorchModel()
         config_list = [{
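Putting the pieces together, one entry of the calibration_config returned by export_model has the shape below for a layer quantized on both weight and input. The key names come from export_model in this commit; the layer name and numeric values are made up for illustration:

calibration_config['conv1'] = {          # illustrative values only
    'weight_bit': 8,
    'tracked_min_weight': -0.42,         # -weight_scale * weight_qmax
    'tracked_max_weight': 0.42,          #  weight_scale * weight_qmax
    'tracked_weight_qmin': -127,
    'tracked_weight_qmax': 127,
    'input_bit': 8,
    'tracked_min_input': -1.27,          # input_scale * (input_qmin - input_zero_point)
    'tracked_max_input': 1.27,           # input_scale * (input_qmax - input_zero_point)
}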
