diff --git a/examples/model_compress/quantization/observer_quantizer.py b/examples/model_compress/quantization/observer_quantizer.py
index 76a07a00da..cb8c59bd07 100644
--- a/examples/model_compress/quantization/observer_quantizer.py
+++ b/examples/model_compress/quantization/observer_quantizer.py
@@ -41,8 +41,8 @@ def test(model, device, test_loader):
 def calibration(model, device, test_loader):
     model.eval()
     with torch.no_grad():
-        for data, target in test_loader:
-            data, target = data.to(device), target.to(device)
+        for data, _ in test_loader:
+            data = data.to(device)
             model(data)
 
 
diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py
index d3c38eb682..d3b55c64ee 100644
--- a/nni/algorithms/compression/pytorch/quantization/quantizers.py
+++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -261,8 +261,6 @@ def compress(self):
                 scale, zero_point = self.calculate_qparams(layer.name, 'weight')
                 module.register_buffer('weight_scale', scale.to(self.device))
                 module.register_buffer('weight_zero_point', zero_point.to(self.device))
-                # todo: recover old_weight to weight, because the compressed
-                # model may be further finetuned.
             if "input" in config.get("quant_types", []):
                 scale, zero_point = self.calculate_qparams(layer.name, 'input')
                 module.register_buffer('input_scale', scale.to(self.device))
@@ -301,11 +299,24 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
         calibration_config = {}
 
         for name, module in self.bound_model.named_modules():
-            if hasattr(module, 'input_scale') or hasattr(module, 'output_scale'):
+            if hasattr(module, 'weight_scale') or hasattr(module, 'input_scale') or hasattr(module, 'output_scale'):
                 calibration_config[name] = {}
+            if hasattr(module, 'weight_scale'):
+                calibration_config[name]['weight_bit'] = 8
+                val = float(module.weight_scale * module.weight_qmax)
+                calibration_config[name]['tracked_min_weight'] = -val
+                calibration_config[name]['tracked_max_weight'] = val
+                calibration_config[name]['tracked_weight_qmin'] = -127
+                calibration_config[name]['tracked_weight_qmax'] = 127
+                actual_weight = getattr(module, 'old_weight', None)
+                if actual_weight is None:
+                    logger.warning("Cannot recover weight for layer %s. "
+                                   "This may lead to a wrong accuracy performance on the backend.", name)
+                delattr(module, 'weight')
+                module.register_parameter('weight', actual_weight)
             # refactor these magic numbers when customizations of dtype and qscheme are ready.
             if hasattr(module, 'input_scale'):
-                calibration_config[name]['weight_bit'] = 8
+                calibration_config[name]['input_bit'] = 8
                 max_input = float(module.input_scale * (module.input_qmax - module.input_zero_point))
                 min_input = float(module.input_scale * (module.input_qmin - module.input_zero_point))
                 calibration_config[name]['tracked_min_input'] = min_input
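For reference, here is a minimal standalone sketch of the symmetric-range arithmetic the new `export_model` branch performs. The `weight_scale` value below is made up for illustration; in the quantizer it is the buffer registered during `compress()`:

```python
import torch

# Hypothetical inputs: in the quantizer these are the 'weight_scale' buffer
# and the 'weight_qmax' attribute set up during compress().
weight_scale = torch.tensor(0.02)
weight_qmax = 127

# Largest weight magnitude representable under symmetric int8 quantization.
val = float(weight_scale * weight_qmax)
entry = {
    'weight_bit': 8,
    'tracked_min_weight': -val,   # symmetric range, so the minimum is just -max
    'tracked_max_weight': val,
    'tracked_weight_qmin': -127,
    'tracked_weight_qmax': 127,
}
print(entry)   # {'weight_bit': 8, 'tracked_min_weight': -2.54, ...}
```

The same hunk also swaps the fp32 weight back in (`old_weight` replaces the quantizer's `weight`), which addresses the todo removed from `compress()`: the exported model keeps its original parameters and can be fine-tuned further.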
diff --git a/test/ut/sdk/test_compressor_torch.py b/test/ut/sdk/test_compressor_torch.py
index c78fa982b1..53abdc34ed 100644
--- a/test/ut/sdk/test_compressor_torch.py
+++ b/test/ut/sdk/test_compressor_torch.py
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import copy
 from unittest import TestCase, main
 import numpy as np
 import torch
@@ -263,6 +264,39 @@ def test_torch_taylorFOweight_pruner_global_sort(self):
         assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.]))
         assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))
 
+    def test_torch_observer_quantizer(self):
+        model = TorchModel()
+        # test invalid config
+        # only 8-bit quantization is supported for now
+        config_list = [{
+            'quant_types': ['weight'],
+            'quant_bits': 5,
+            'op_types': ['Conv2d', 'Linear']
+        }]
+        with self.assertRaises(schema.SchemaError):
+            torch_quantizer.ObserverQuantizer(model, config_list)
+
+        # weights are not modified by the quantizer for now
+        model = TorchModel().eval()
+        origin_parameters = copy.deepcopy(dict(model.named_parameters()))
+
+        config_list = [{
+            'quant_types': ['weight'],
+            'quant_bits': 8,
+            'op_types': ['Conv2d', 'Linear']
+        }]
+        quantizer = torch_quantizer.ObserverQuantizer(model, config_list)
+        input = torch.randn(1, 1, 28, 28)
+        model(input)
+        quantizer.compress()
+        model_path = "test_model.pth"
+        calibration_path = "test_calibration.pth"
+        calibration_config = quantizer.export_model(model_path, calibration_path)
+        new_parameters = dict(model.named_parameters())
+        self.assertTrue(all(torch.equal(v, new_parameters[k]) for k, v in origin_parameters.items()))
+        self.assertTrue(calibration_config is not None)
+        self.assertTrue(len(calibration_config) == 4)
+
     def test_torch_QAT_quantizer(self):
         model = TorchModel()
         config_list = [{
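Taken together, here is a minimal end-to-end sketch of the flow the new unit test exercises. It assumes an NNI installation where this import path resolves; `Net` is a toy stand-in for the test suite's `TorchModel`:

```python
import torch
import torch.nn as nn
from nni.algorithms.compression.pytorch.quantization import ObserverQuantizer

class Net(nn.Module):
    """Toy model exposing the op types targeted by the config below."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, 3)        # 28x28 input -> 26x26 feature map
        self.fc = nn.Linear(4 * 26 * 26, 10)

    def forward(self, x):
        return self.fc(self.conv(x).flatten(1))

model = Net().eval()
config_list = [{
    'quant_types': ['weight'],
    'quant_bits': 8,                          # only 8 bit is accepted
    'op_types': ['Conv2d', 'Linear'],
}]
quantizer = ObserverQuantizer(model, config_list)
model(torch.randn(16, 1, 28, 28))             # calibration pass: observers record ranges
quantizer.compress()
calibration_config = quantizer.export_model('model.pth', 'calibration.pth')
print(calibration_config)                     # per-layer weight_bit / tracked_*_weight entries
```

The test's `len(calibration_config) == 4` assertion matches the four quantizable layers in `TorchModel` (presumably two `Conv2d` and two `Linear`); the sketch above would produce two entries instead.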