From c0821488972ba215cddf482d9437b68dcce8f9c9 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Thu, 8 Jul 2021 11:20:46 +0800 Subject: [PATCH 01/13] Add post training observer_quantizer --- .../quantization/observer_quantizer.py | 155 +++++++++++++ .../pytorch/quantization/observers.py | 3 + .../pytorch/quantization/quantizers.py | 216 +++++++++++++++++- 3 files changed, 373 insertions(+), 1 deletion(-) create mode 100644 examples/model_compress/quantization/observer_quantizer.py create mode 100644 nni/algorithms/compression/pytorch/quantization/observers.py diff --git a/examples/model_compress/quantization/observer_quantizer.py b/examples/model_compress/quantization/observer_quantizer.py new file mode 100644 index 0000000000..ac7a36a0d6 --- /dev/null +++ b/examples/model_compress/quantization/observer_quantizer.py @@ -0,0 +1,155 @@ +import torch +import torch.nn.functional as F +from torchvision import datasets, transforms +from nni.algorithms.compression.pytorch.quantization import ObserverQuantizer + + +class Mnist(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 20, 5, 1) + self.conv2 = torch.nn.Conv2d(20, 50, 5, 1) + self.fc1 = torch.nn.Linear(4 * 4 * 50, 500) + self.fc2 = torch.nn.Linear(500, 10) + self.relu1 = torch.nn.ReLU6() + self.relu2 = torch.nn.ReLU6() + self.relu3 = torch.nn.ReLU6() + self.max_pool1 = torch.nn.MaxPool2d(2, 2) + self.max_pool2 = torch.nn.MaxPool2d(2, 2) + + def forward(self, x): + x = self.relu1(self.conv1(x)) + x = self.max_pool1(x) + x = self.relu2(self.conv2(x)) + x = self.max_pool2(x) + x = x.view(-1, 4 * 4 * 50) + x = self.relu3(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + +def train(model, device, train_loader, optimizer): + model.to(device) + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % 100 == 0: + print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item())) + + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + test_loss /= len(test_loader.dataset) + + print('Loss: {} Accuracy: {}%)\n'.format( + test_loss, 100 * correct / len(test_loader.dataset))) + + +def calibration(model, device, test_loader): + model.eval() + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + model(data) + + +def test_trt(engine, test_loader): + test_loss = 0 + correct = 0 + time_elasped = 0 + for data, target in test_loader: + output, time = engine.inference(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + time_elasped += time + test_loss /= len(test_loader.dataset) + + print('Loss: {} Accuracy: {}%'.format( + test_loss, 100 * correct / len(test_loader.dataset))) + print("Inference elapsed_time (whole dataset): {}s".format(time_elasped)) + + +def main(): + torch.manual_seed(0) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + trans = 
transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) + train_loader = torch.utils.data.DataLoader( + datasets.MNIST('data', train=True, download=True, transform=trans), + batch_size=64, shuffle=True) + test_loader = torch.utils.data.DataLoader( + datasets.MNIST('data', train=False, transform=trans), + batch_size=1000, shuffle=True) + + model = Mnist() + configure_list = [{ + 'quant_types': ['weight', 'input'], + 'quant_bits': {'weight': 8, 'input': 8}, + 'op_names': ['conv1'], + }, { + 'quant_types': ['output'], + 'quant_bits': {'output': 8, }, + 'op_names': ['relu1'], + }, { + 'quant_types': ['weight', 'input'], + 'quant_bits': {'weight': 8, 'input': 8}, + 'op_names': ['conv2'], + }, { + 'quant_types': ['output'], + 'quant_bits': {'output': 8}, + 'op_names': ['relu2'], + }, { + 'quant_types': ['output'], + 'quant_bits': {'output': 8}, + 'op_names': ['max_pool2'], + } + ] + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) + + # Train the model to get a baseline performance + for epoch in range(5): + print('# Epoch {} #'.format(epoch)) + train(model, device, train_loader, optimizer) + + test(model, device, test_loader) + + # Construct the ObserverQuantizer. Note that currently ObserverQuantizer only works + # in evaluation mode. + quantizer = ObserverQuantizer(model.eval(), configure_list, optimizer) + # Use the test data set to do calibration, this will not change the model parameters + calibration(model, device, test_loader) + # obtain the quantization information and switch the model to "accuracy verification" mode + quantizer.compress() + + # measure the accuracy of the quantized model. + test(model, device, test_loader) + + model_path = "mnist_model.pth" + calibration_path = "mnist_calibration.pth" + calibration_config = quantizer.export_model(model_path, calibration_path) + print("calibration_config: ", calibration_config) + + # For now the quantization settings of ObserverQuantizer does not match the TensorRT, + # so TensorRT conversion are not supported + # current settings: + # weight : per_tensor_symmetric, qint8 + # activation : per_tensor_affine, quint8, reduce_range=True + + +if __name__ == '__main__': + main() diff --git a/nni/algorithms/compression/pytorch/quantization/observers.py b/nni/algorithms/compression/pytorch/quantization/observers.py new file mode 100644 index 0000000000..7631f46ccd --- /dev/null +++ b/nni/algorithms/compression/pytorch/quantization/observers.py @@ -0,0 +1,3 @@ +from torch.quantization import default_weight_observer, default_histogram_observer + +__all__ = ["default_weight_observer", "default_histogram_observer"] diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 09e53329b0..145caa44bb 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -4,11 +4,14 @@ import logging import copy import torch +from collections import defaultdict from schema import Schema, And, Or, Optional from nni.compression.pytorch.utils.config_validation import QuantizerSchema from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType -__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer'] +from .observers import default_weight_observer, default_histogram_observer + +__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 
'LsqQuantizer', 'ObserverQuantizer'] logger = logging.getLogger(__name__) @@ -120,6 +123,217 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qma return grad_output +class ObserverQuantizer(Quantizer): + """ + + """ + + def __init__(self, model, config_list, optimizer=None): + super().__init__(model, config_list, optimizer) + # NOTE: this quantizer is experimental for now. The dtype and qscheme of quantization + # is hard-coded. + # TODO: + # 1. support dtype and qscheme customization through config_list. Current settings: + # weight observer : per_tensor_symmetric, qint8 + # activation observer : per_tensor_affine, quint8, reduce_range=True + # 2. add more kinds of observers, such as Kullback-Leibler divergence. + # 3. add batch normalization folding + assert not model.training, "Currently observer quantizer only works in evaluation mode." + self.quant_grad = QuantForward() + self.device = next(model.parameters()).device + modules_to_compress = self.get_modules_to_compress() + all_observers = defaultdict(dict) + weight_q_min, weight_q_max = -127, 127 + activation_q_min, activation_q_max = 0, 127 # reduce_range is set to True + self.compressed = False + + for layer, config in modules_to_compress: + layer_name = layer.name + module = layer.module + if "weight" in config.get("quant_types", []): + all_observers[layer_name]["weight"] = default_weight_observer() + setattr(module, "weight_qmax", weight_q_max) + setattr(module, "weight_qmin", weight_q_min) + if "input" in config.get("quant_types", []): + all_observers[layer_name]["input"] = default_histogram_observer() + setattr(module, "input_qmax", activation_q_max) + setattr(module, "input_qmin", activation_q_min) + if "output" in config.get("quant_types", []): + all_observers[layer_name]["output"] = default_histogram_observer() + setattr(module, "output_qmax", activation_q_max) + setattr(module, "output_qmin", activation_q_min) + self.all_observers = all_observers + self.bound_model.to(self.device) + + def validate_config(self, model, config_list): + schema = CompressorSchema([{ + Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]), + Optional('quant_bits'): Or(And(int, lambda n: n == 8), Schema({ + Optional('weight'): And(int, lambda n: n == 8), + Optional('output'): And(int, lambda n: n == 8), + Optional('input'): And(int, lambda n: n == 8), + })), + Optional('op_types'): [str], + Optional('op_names'): [str] + }], model, logger) + + schema.validate(config_list) + + def record(self, wrapper, type, tensor): + name = wrapper.name + observer = self.all_observers[name][type] + if isinstance(tensor, tuple): + # NB: This only works for single tensor + tensor = (t.cpu() for t in tensor) + observer(*tensor) + else: + observer(tensor.cpu()) + + def calculate_qparams(self, name, type): + observer = self.all_observers[name][type] + scale, zero_point = observer.calculate_qparams() + return scale, zero_point + + def quantize(self, x, scale, zero_point, qmin, qmax): + x = x / scale + zero_point + x = torch.clamp(x, qmin, qmax) + x = torch.round(x) + x = (x - zero_point) * scale + return x + + def quantize_input(self, *inputs, wrapper, **kwargs): + if self.compressed: + module = wrapper.module + new_input = self.quantize(inputs[0], + module.input_scale, + module.input_zero_point, + module.input_qmin, + module.input_qmax) + list_inp = list(inputs) + list_inp[0] = new_input + inputs = tuple(list_inp) + else: + self.record(wrapper, 'input', inputs) + return inputs + + def quantize_weight(self, 
wrapper, **kwargs): + module = wrapper.module + old_weight = module.weight + if self.compressed: + new_weight = self.quantize(old_weight, + module.weight_scale, + module.weight_zero_point, + module.weight_qmin, + module.weight_qmax) + else: + self.record(wrapper, 'weight', old_weight) + new_weight = old_weight + return new_weight + + def quantize_output(self, output, wrapper, **kwargs): + if self.compressed: + module = wrapper.module + new_output = self.quantize(output, + module.output_scale, + module.output_zero_point, + module.output_qmin, + module.output_qmax) + else: + self.record(wrapper, 'output', output) + new_output = output + return new_output + + def compress(self): + """ + Calculate quantization information of each tensor. Note that the inference of + the compressed model will no longer update the corresponding. Instead, the quantization + process will be simulated, which is used to test the accuracy of the quantization. + """ + modules_to_compress = self.get_modules_to_compress() + for layer, config in modules_to_compress: + module = layer.module + if "weight" in config.get("quant_types", []): + scale, zero_point = self.calculate_qparams(layer.name, 'weight') + module.register_buffer('weight_scale', scale.to(self.device)) + module.register_buffer('weight_zero_point', zero_point.to(self.device)) + # todo: recover old_weight to weight, because the compressed + # model may be further finetuned. + if "input" in config.get("quant_types", []): + scale, zero_point = self.calculate_qparams(layer.name, 'input') + module.register_buffer('input_scale', scale.to(self.device)) + module.register_buffer('input_zero_point', zero_point.to(self.device)) + if "output" in config.get("quant_types", []): + scale, zero_point = self.calculate_qparams(layer.name, 'output') + module.register_buffer('output_scale', scale.to(self.device)) + module.register_buffer('output_zero_point', zero_point.to(self.device)) + self.compressed = True + super().compress() + + def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None): + """ + Export quantized model weights and calibration parameters(optional) + + Parameters + ---------- + model_path : str + path to save quantized model weight + calibration_path : str + (optional) path to save quantize parameters after calibration + onnx_path : str + (optional) path to save onnx model + input_shape : list or tuple + input shape to onnx model + device : torch.device + device of the model, used to place the dummy input tensor for exporting onnx file. + the tensor is placed on cpu if ```device``` is None + + Returns + ------- + Dict + """ + assert model_path is not None, 'model_path must be specified' + self._unwrap_model() + calibration_config = {} + + for name, module in self.bound_model.named_modules(): + if hasattr(module, 'input_scale') or hasattr(module, 'output_scale'): + calibration_config[name] = {} + # refactor these magic numbers when customizations of dtype and qscheme are ready. 
+ if hasattr(module, 'input_scale'): + calibration_config[name]['weight_bit'] = 8 + max_input = float(module.input_scale * (module.input_qmax - module.input_zero_point)) + min_input = float(module.input_scale * (module.input_qmin - module.input_zero_point)) + calibration_config[name]['tracked_min_input'] = min_input + calibration_config[name]['tracked_max_input'] = max_input + calibration_config[name]['tracked_input_qmin'] = 0 + calibration_config[name]['tracked_input_qmax'] = 255 + if hasattr(module, 'output_scale'): + calibration_config[name]['activation_bit'] = 8 + max_input = float(module.output_scale * (module.output_qmax - module.output_zero_point)) + min_input = float(module.output_scale * (module.output_qmin - module.output_zero_point)) + calibration_config[name]['tracked_min_activation'] = min_input + calibration_config[name]['tracked_max_activation'] = max_input + calibration_config[name]['tracked_activation_qmin'] = 0 + calibration_config[name]['tracked_activation_qmax'] = 255 + self._del_simulated_attr(module) + + self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, + input_shape, device) + + return calibration_config + + def _del_simulated_attr(self, module): + """ + delete redundant parameters in quantize module + """ + del_attr_list = ['old_weight', 'steps', 'weight_qmax', 'weight_qmin', 'input_qmax', 'input_qmin', + 'output_qmax', 'output_qmin', 'weight_scale', 'weight_zero_point', 'input_scale', + 'input_zero_point', 'output_scale', 'output_zero_point'] + for attr in del_attr_list: + if hasattr(module, attr): + delattr(module, attr) + + class QAT_Quantizer(Quantizer): """Quantizer defined in: Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference From 11a40c31682316b7ef3b9e32e49e118f06bf37af Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Thu, 8 Jul 2021 12:58:43 +0800 Subject: [PATCH 02/13] fix linter --- .../compression/pytorch/quantization/quantizers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 145caa44bb..58af5c6c71 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -3,8 +3,8 @@ import logging import copy -import torch from collections import defaultdict +import torch from schema import Schema, And, Or, Optional from nni.compression.pytorch.utils.config_validation import QuantizerSchema from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType @@ -179,9 +179,9 @@ def validate_config(self, model, config_list): schema.validate(config_list) - def record(self, wrapper, type, tensor): + def record(self, wrapper, quant_type, tensor): name = wrapper.name - observer = self.all_observers[name][type] + observer = self.all_observers[name][quant_type] if isinstance(tensor, tuple): # NB: This only works for single tensor tensor = (t.cpu() for t in tensor) @@ -189,8 +189,8 @@ def record(self, wrapper, type, tensor): else: observer(tensor.cpu()) - def calculate_qparams(self, name, type): - observer = self.all_observers[name][type] + def calculate_qparams(self, name, quant_type): + observer = self.all_observers[name][quant_type] scale, zero_point = observer.calculate_qparams() return scale, zero_point From b717e53c5cb5cf1835e74db3cccc2281f8f55de4 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: 
Tue, 13 Jul 2021 10:13:52 +0800 Subject: [PATCH 03/13] refine --- .../quantization/observer_quantizer.py | 46 ++----------------- .../pytorch/quantization/quantizers.py | 16 +++++-- 2 files changed, 15 insertions(+), 47 deletions(-) diff --git a/examples/model_compress/quantization/observer_quantizer.py b/examples/model_compress/quantization/observer_quantizer.py index ac7a36a0d6..76a07a00da 100644 --- a/examples/model_compress/quantization/observer_quantizer.py +++ b/examples/model_compress/quantization/observer_quantizer.py @@ -2,30 +2,9 @@ import torch.nn.functional as F from torchvision import datasets, transforms from nni.algorithms.compression.pytorch.quantization import ObserverQuantizer - - -class Mnist(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv1 = torch.nn.Conv2d(1, 20, 5, 1) - self.conv2 = torch.nn.Conv2d(20, 50, 5, 1) - self.fc1 = torch.nn.Linear(4 * 4 * 50, 500) - self.fc2 = torch.nn.Linear(500, 10) - self.relu1 = torch.nn.ReLU6() - self.relu2 = torch.nn.ReLU6() - self.relu3 = torch.nn.ReLU6() - self.max_pool1 = torch.nn.MaxPool2d(2, 2) - self.max_pool2 = torch.nn.MaxPool2d(2, 2) - - def forward(self, x): - x = self.relu1(self.conv1(x)) - x = self.max_pool1(x) - x = self.relu2(self.conv2(x)) - x = self.max_pool2(x) - x = x.view(-1, 4 * 4 * 50) - x = self.relu3(self.fc1(x)) - x = self.fc2(x) - return F.log_softmax(x, dim=1) +import sys +sys.path.append('../models') +from mnist.naive import NaiveModel def train(model, device, train_loader, optimizer): @@ -67,23 +46,6 @@ def calibration(model, device, test_loader): model(data) -def test_trt(engine, test_loader): - test_loss = 0 - correct = 0 - time_elasped = 0 - for data, target in test_loader: - output, time = engine.inference(data) - test_loss += F.nll_loss(output, target, reduction='sum').item() - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq(target.view_as(pred)).sum().item() - time_elasped += time - test_loss /= len(test_loader.dataset) - - print('Loss: {} Accuracy: {}%'.format( - test_loss, 100 * correct / len(test_loader.dataset))) - print("Inference elapsed_time (whole dataset): {}s".format(time_elasped)) - - def main(): torch.manual_seed(0) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -96,7 +58,7 @@ def main(): datasets.MNIST('data', train=False, transform=trans), batch_size=1000, shuffle=True) - model = Mnist() + model = NaiveModel() configure_list = [{ 'quant_types': ['weight', 'input'], 'quant_bits': {'weight': 8, 'input': 8}, diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 58af5c6c71..edc3152148 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -125,7 +125,13 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qma class ObserverQuantizer(Quantizer): """ - + This quantizer uses observers to record weight/activation statistics to get quantization + information. The whole process can be divided into three steps: + 1. It will register observers to the place where quantization would happen (just like registering hooks). + 2. The observers would record tensors' statistics during calibration. + 3. Scale & zero point would be obtained after calibration. + Note that the observer type, tensor dtype and quantization qscheme are hard coded for now. Their customization + are under development and will be ready soon. 
""" def __init__(self, model, config_list, optimizer=None): @@ -194,7 +200,7 @@ def calculate_qparams(self, name, quant_type): scale, zero_point = observer.calculate_qparams() return scale, zero_point - def quantize(self, x, scale, zero_point, qmin, qmax): + def _quantize(self, x, scale, zero_point, qmin, qmax): x = x / scale + zero_point x = torch.clamp(x, qmin, qmax) x = torch.round(x) @@ -204,7 +210,7 @@ def quantize(self, x, scale, zero_point, qmin, qmax): def quantize_input(self, *inputs, wrapper, **kwargs): if self.compressed: module = wrapper.module - new_input = self.quantize(inputs[0], + new_input = self._quantize(inputs[0], module.input_scale, module.input_zero_point, module.input_qmin, @@ -220,7 +226,7 @@ def quantize_weight(self, wrapper, **kwargs): module = wrapper.module old_weight = module.weight if self.compressed: - new_weight = self.quantize(old_weight, + new_weight = self._quantize(old_weight, module.weight_scale, module.weight_zero_point, module.weight_qmin, @@ -233,7 +239,7 @@ def quantize_weight(self, wrapper, **kwargs): def quantize_output(self, output, wrapper, **kwargs): if self.compressed: module = wrapper.module - new_output = self.quantize(output, + new_output = self._quantize(output, module.output_scale, module.output_zero_point, module.output_qmin, From 21bfa52abd2e2bdc5e6105710b429c343d2920eb Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Tue, 13 Jul 2021 11:01:49 +0800 Subject: [PATCH 04/13] refine --- .../compression/pytorch/quantization/quantizers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index edc3152148..6c8ff875ba 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -124,9 +124,8 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qma class ObserverQuantizer(Quantizer): - """ - This quantizer uses observers to record weight/activation statistics to get quantization - information. The whole process can be divided into three steps: + """This quantizer uses observers to record weight/activation statistics to get quantization information. + The whole process can be divided into three steps: 1. It will register observers to the place where quantization would happen (just like registering hooks). 2. The observers would record tensors' statistics during calibration. 3. Scale & zero point would be obtained after calibration. 
From d124415a1628e7b0fd598a77ae48d037d17e009b Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Tue, 13 Jul 2021 11:06:37 +0800 Subject: [PATCH 05/13] rebase --- nni/algorithms/compression/pytorch/quantization/quantizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 6c8ff875ba..d3c38eb682 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -171,7 +171,7 @@ def __init__(self, model, config_list, optimizer=None): self.bound_model.to(self.device) def validate_config(self, model, config_list): - schema = CompressorSchema([{ + schema = QuantizerSchema([{ Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]), Optional('quant_bits'): Or(And(int, lambda n: n == 8), Schema({ Optional('weight'): And(int, lambda n: n == 8), From 656c63d7fb6cb3d0b0b4688e013eb1f6f7af56ab Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Mon, 26 Jul 2021 12:17:59 +0800 Subject: [PATCH 06/13] refine --- .../quantization/observer_quantizer.py | 4 +-- .../pytorch/quantization/quantizers.py | 19 ++++++++--- test/ut/sdk/test_compressor_torch.py | 34 +++++++++++++++++++ 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/examples/model_compress/quantization/observer_quantizer.py b/examples/model_compress/quantization/observer_quantizer.py index 76a07a00da..cb8c59bd07 100644 --- a/examples/model_compress/quantization/observer_quantizer.py +++ b/examples/model_compress/quantization/observer_quantizer.py @@ -41,8 +41,8 @@ def test(model, device, test_loader): def calibration(model, device, test_loader): model.eval() with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) + for data, _ in test_loader: + data = data.to(device) model(data) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index d3c38eb682..d3b55c64ee 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -261,8 +261,6 @@ def compress(self): scale, zero_point = self.calculate_qparams(layer.name, 'weight') module.register_buffer('weight_scale', scale.to(self.device)) module.register_buffer('weight_zero_point', zero_point.to(self.device)) - # todo: recover old_weight to weight, because the compressed - # model may be further finetuned. 
if "input" in config.get("quant_types", []): scale, zero_point = self.calculate_qparams(layer.name, 'input') module.register_buffer('input_scale', scale.to(self.device)) @@ -301,11 +299,24 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ calibration_config = {} for name, module in self.bound_model.named_modules(): - if hasattr(module, 'input_scale') or hasattr(module, 'output_scale'): + if hasattr(module, 'weight_scale') or hasattr(module, 'input_scale') or hasattr(module, 'output_scale'): calibration_config[name] = {} + if hasattr(module, 'weight_scale'): + calibration_config[name]['weight_bit'] = 8 + val = float(module.weight_scale * module.weight_qmax) + calibration_config[name]['tracked_min_weight'] = val + calibration_config[name]['tracked_max_weight'] = -val + calibration_config[name]['tracked_weight_qmin'] = -127 + calibration_config[name]['tracked_weight_qmax'] = 127 + actual_weight = getattr(module, 'old_weight', None) + if actual_weight is None: + logger.warning("Can not recover weight for layer %s. " + "This may lead to a wrong accuracy performance on the backend.", name) + delattr(module, 'weight') + module.register_parameter('weight', actual_weight) # refactor these magic numbers when customizations of dtype and qscheme are ready. if hasattr(module, 'input_scale'): - calibration_config[name]['weight_bit'] = 8 + calibration_config[name]['input_bit'] = 8 max_input = float(module.input_scale * (module.input_qmax - module.input_zero_point)) min_input = float(module.input_scale * (module.input_qmin - module.input_zero_point)) calibration_config[name]['tracked_min_input'] = min_input diff --git a/test/ut/sdk/test_compressor_torch.py b/test/ut/sdk/test_compressor_torch.py index c78fa982b1..53abdc34ed 100644 --- a/test/ut/sdk/test_compressor_torch.py +++ b/test/ut/sdk/test_compressor_torch.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
+import copy from unittest import TestCase, main import numpy as np import torch @@ -263,6 +264,39 @@ def test_torch_taylorFOweight_pruner_global_sort(self): assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.])) assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.])) + def test_torch_observer_quantizer(self): + model = TorchModel() + # test invalid config + # only support 8bit for now + config_list = [{ + 'quant_types': ['weight'], + 'quant_bits': 5, + 'op_types': ['Conv2d', 'Linear'] + }] + with self.assertRaises(schema.SchemaError): + torch_quantizer.ObserverQuantizer(model, config_list) + + # weight will not change for now + model = TorchModel().eval() + origin_parameters = copy.deepcopy(dict(model.named_parameters())) + + config_list = [{ + 'quant_types': ['weight'], + 'quant_bits': 8, + 'op_types': ['Conv2d', 'Linear'] + }] + quantizer = torch_quantizer.ObserverQuantizer(model, config_list) + input = torch.randn(1, 1, 28, 28) + model(input) + quantizer.compress() + model_path = "test_model.pth" + calibration_path = "test_calibration.pth" + calibration_config = quantizer.export_model(model_path, calibration_path) + new_parameters = dict(model.named_parameters()) + self.assertTrue(all(torch.equal(v, new_parameters[k]) for k, v in origin_parameters.items())) + self.assertTrue(calibration_config is not None) + self.assertTrue(len(calibration_config) == 4) + def test_torch_QAT_quantizer(self): model = TorchModel() config_list = [{ From 8aef2c22e57d1cf39ffd634823ed492639907467 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Mon, 26 Jul 2021 17:05:21 +0800 Subject: [PATCH 07/13] refine --- nni/algorithms/compression/pytorch/quantization/quantizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index d3b55c64ee..617adb3227 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -143,7 +143,7 @@ def __init__(self, model, config_list, optimizer=None): # activation observer : per_tensor_affine, quint8, reduce_range=True # 2. add more kinds of observers, such as Kullback-Leibler divergence. # 3. add batch normalization folding - assert not model.training, "Currently observer quantizer only works in evaluation mode." + assert not model.training, "Currently the observer quantizer only works in evaluation mode." 
self.quant_grad = QuantForward() self.device = next(model.parameters()).device modules_to_compress = self.get_modules_to_compress() From 455cc122d295268abaa84967dda86b69f0eb2f6d Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Tue, 27 Jul 2021 09:53:43 +0800 Subject: [PATCH 08/13] refine --- nni/algorithms/compression/pytorch/quantization/quantizers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 617adb3227..6e7efff09c 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -304,8 +304,8 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ if hasattr(module, 'weight_scale'): calibration_config[name]['weight_bit'] = 8 val = float(module.weight_scale * module.weight_qmax) - calibration_config[name]['tracked_min_weight'] = val - calibration_config[name]['tracked_max_weight'] = -val + calibration_config[name]['tracked_max_weight'] = val + calibration_config[name]['tracked_min_weight'] = -val calibration_config[name]['tracked_weight_qmin'] = -127 calibration_config[name]['tracked_weight_qmax'] = 127 actual_weight = getattr(module, 'old_weight', None) From ab32e8e540a07cd475c584c7bdaf5677214ec7ec Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Tue, 27 Jul 2021 14:33:42 +0800 Subject: [PATCH 09/13] simulate quantization in export_model --- .../compression/pytorch/quantization/quantizers.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 6e7efff09c..2097008b53 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -312,8 +312,15 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ if actual_weight is None: logger.warning("Can not recover weight for layer %s. " "This may lead to a wrong accuracy performance on the backend.", name) + continue + # simulate quantization + actual_quantized_weight = self._quantize(actual_weight, + module.weight_scale, + module.weight_zero_point, + module.weight_qmin, + module.weight_qmax) delattr(module, 'weight') - module.register_parameter('weight', actual_weight) + module.register_parameter('weight', torch.nn.Parameter(actual_quantized_weight)) # refactor these magic numbers when customizations of dtype and qscheme are ready. 
if hasattr(module, 'input_scale'): calibration_config[name]['input_bit'] = 8 From 28be4fe0f75976625ecb950766e234c7e7877df7 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Tue, 27 Jul 2021 14:54:15 +0800 Subject: [PATCH 10/13] assign new_weight to weight in eval mode --- nni/algorithms/compression/pytorch/quantization/quantizers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 2097008b53..59002b2549 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -230,6 +230,7 @@ def quantize_weight(self, wrapper, **kwargs): module.weight_zero_point, module.weight_qmin, module.weight_qmax) + module.weight = new_weight else: self.record(wrapper, 'weight', old_weight) new_weight = old_weight @@ -313,7 +314,7 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ logger.warning("Can not recover weight for layer %s. " "This may lead to a wrong accuracy performance on the backend.", name) continue - # simulate quantization + # simulate quantization. actual_quantized_weight = self._quantize(actual_weight, module.weight_scale, module.weight_zero_point, From 455340a6b446f86573bfc77eac4e2f791dd97091 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Tue, 27 Jul 2021 16:01:51 +0800 Subject: [PATCH 11/13] quantize weight in compress --- .../pytorch/quantization/quantizers.py | 38 +++++++------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 59002b2549..07dac56be1 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -222,19 +222,14 @@ def quantize_input(self, *inputs, wrapper, **kwargs): return inputs def quantize_weight(self, wrapper, **kwargs): + # If ObserverQuantizer.compress is executed, the weight will be set to + # the Pseudo-quantized one. 
So there is no need to quantize it + if self.compressed: + return + module = wrapper.module old_weight = module.weight - if self.compressed: - new_weight = self._quantize(old_weight, - module.weight_scale, - module.weight_zero_point, - module.weight_qmin, - module.weight_qmax) - module.weight = new_weight - else: - self.record(wrapper, 'weight', old_weight) - new_weight = old_weight - return new_weight + self.record(wrapper, 'weight', old_weight) def quantize_output(self, output, wrapper, **kwargs): if self.compressed: @@ -262,6 +257,14 @@ def compress(self): scale, zero_point = self.calculate_qparams(layer.name, 'weight') module.register_buffer('weight_scale', scale.to(self.device)) module.register_buffer('weight_zero_point', zero_point.to(self.device)) + weight = module.weight + quantized_weight = self._quantize(weight, + module.weight_scale, + module.weight_zero_point, + module.weight_qmin, + module.weight_qmax) + delattr(module, 'weight') + module.register_parameter('weight', torch.nn.Parameter(quantized_weight)) if "input" in config.get("quant_types", []): scale, zero_point = self.calculate_qparams(layer.name, 'input') module.register_buffer('input_scale', scale.to(self.device)) @@ -309,19 +312,6 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ calibration_config[name]['tracked_min_weight'] = -val calibration_config[name]['tracked_weight_qmin'] = -127 calibration_config[name]['tracked_weight_qmax'] = 127 - actual_weight = getattr(module, 'old_weight', None) - if actual_weight is None: - logger.warning("Can not recover weight for layer %s. " - "This may lead to a wrong accuracy performance on the backend.", name) - continue - # simulate quantization. - actual_quantized_weight = self._quantize(actual_weight, - module.weight_scale, - module.weight_zero_point, - module.weight_qmin, - module.weight_qmax) - delattr(module, 'weight') - module.register_parameter('weight', torch.nn.Parameter(actual_quantized_weight)) # refactor these magic numbers when customizations of dtype and qscheme are ready. 
if hasattr(module, 'input_scale'): calibration_config[name]['input_bit'] = 8 From b67bda511b549009f37348c4e7a14e4ad4015ed2 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Tue, 27 Jul 2021 17:37:06 +0800 Subject: [PATCH 12/13] update ut --- test/ut/sdk/test_compressor_torch.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/ut/sdk/test_compressor_torch.py b/test/ut/sdk/test_compressor_torch.py index 53abdc34ed..32f503354c 100644 --- a/test/ut/sdk/test_compressor_torch.py +++ b/test/ut/sdk/test_compressor_torch.py @@ -289,11 +289,18 @@ def test_torch_observer_quantizer(self): input = torch.randn(1, 1, 28, 28) model(input) quantizer.compress() + buffers = dict(model.named_buffers()) + scales = {k: v for k, v in buffers.items() if 'scale' in k} model_path = "test_model.pth" calibration_path = "test_calibration.pth" calibration_config = quantizer.export_model(model_path, calibration_path) new_parameters = dict(model.named_parameters()) - self.assertTrue(all(torch.equal(v, new_parameters[k]) for k, v in origin_parameters.items())) + for layer_name, v in calibration_config.items(): + scale_name = layer_name + '.module.weight_scale' + weight_name = layer_name + '.weight' + s = float(scales[scale_name]) + self.assertTrue(torch.allclose(origin_parameters[weight_name], new_parameters[weight_name], atol=0.5 * s)) + self.assertTrue(calibration_config is not None) self.assertTrue(len(calibration_config) == 4) From 365d81287c2dcf2b6d250ed1559c598623c3436c Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Tue, 27 Jul 2021 19:47:01 +0800 Subject: [PATCH 13/13] fix wrong range --- nni/algorithms/compression/pytorch/quantization/quantizers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 07dac56be1..da2acded8b 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -320,7 +320,7 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ calibration_config[name]['tracked_min_input'] = min_input calibration_config[name]['tracked_max_input'] = max_input calibration_config[name]['tracked_input_qmin'] = 0 - calibration_config[name]['tracked_input_qmax'] = 255 + calibration_config[name]['tracked_input_qmax'] = 127 if hasattr(module, 'output_scale'): calibration_config[name]['activation_bit'] = 8 max_input = float(module.output_scale * (module.output_qmax - module.output_zero_point)) @@ -328,7 +328,7 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ calibration_config[name]['tracked_min_activation'] = min_input calibration_config[name]['tracked_max_activation'] = max_input calibration_config[name]['tracked_activation_qmin'] = 0 - calibration_config[name]['tracked_activation_qmax'] = 255 + calibration_config[name]['tracked_activation_qmax'] = 127 self._del_simulated_attr(module) self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
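The workflow these commits add up to — construct the ObserverQuantizer on an eval-mode model, feed calibration data through it so the observers record statistics, call compress() to derive scale/zero_point and switch to simulated quantization, then export_model() to dump weights and calibration parameters — can be condensed into the short sketch below. This is a minimal sketch based on the example script and unit test in this patch series; the toy Sequential model, the random calibration batches, and the fake_quantize helper name are illustrative assumptions, not part of the patch.

    import torch
    import torch.nn as nn
    from nni.algorithms.compression.pytorch.quantization import ObserverQuantizer

    def fake_quantize(x, scale, zero_point, qmin, qmax):
        # Mirrors ObserverQuantizer._quantize: map to the integer grid,
        # clamp to [qmin, qmax], round, then de-quantize back to float.
        q = torch.round(torch.clamp(x / scale + zero_point, qmin, qmax))
        return (q - zero_point) * scale

    # Toy stand-in for the MNIST example model and its calibration data.
    model = nn.Sequential(
        nn.Conv2d(1, 8, 3), nn.ReLU6(), nn.Flatten(), nn.Linear(8 * 26 * 26, 10)
    ).eval()                                   # the quantizer only works in eval mode
    calibration_batches = [torch.randn(4, 1, 28, 28) for _ in range(8)]

    config_list = [{
        'quant_types': ['weight', 'input'],
        'quant_bits': {'weight': 8, 'input': 8},   # validate_config only accepts 8 bit
        'op_types': ['Conv2d', 'Linear'],
    }]

    quantizer = ObserverQuantizer(model, config_list)

    with torch.no_grad():                      # calibration: observers record tensor statistics
        for batch in calibration_batches:
            model(batch)

    quantizer.compress()                       # compute scale / zero_point, start simulating quantization
    calibration_config = quantizer.export_model("mnist_model.pth", "mnist_calibration.pth")
    print(calibration_config)

After compress(), weights are replaced by their fake-quantized values and inputs/outputs are fake-quantized on the fly with the same formula. The exported ranges follow directly from the quantization parameters, e.g. tracked_max_input = input_scale * (input_qmax - input_zero_point); with the hard-coded settings (per_tensor_symmetric qint8 weights, per_tensor_affine quint8 activations with reduce_range=True) the integer ranges are -127..127 for weights and 0..127 for activations, which is what the last commit corrects in export_model.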