From f0e3c584f415a39f77d35b098b7fb5fc14ccc5fc Mon Sep 17 00:00:00 2001 From: lin bin Date: Fri, 9 Apr 2021 09:34:08 +0800 Subject: [PATCH] Combine tensorrt tool with NNI quantization algorithms. (#3488) --- .../mixed_precision_speedup_mnist.py | 169 ++++++++ .../pytorch/quantization/quantizers.py | 42 +- nni/compression/pytorch/compressor.py | 10 +- .../pytorch/quantization_speedup/__init__.py | 1 + .../pytorch/quantization_speedup/backend.py | 51 +++ .../quantization_speedup/calibrator.py | 99 +++++ .../quantization_speedup/frontend_to_onnx.py | 148 +++++++ .../integrated_tensorrt.py | 381 ++++++++++++++++++ .../quantization_speedup/trt_pycuda.py | 86 ++++ pylintrc | 4 +- test/ut/sdk/test_compressor_torch.py | 15 +- 11 files changed, 977 insertions(+), 29 deletions(-) create mode 100644 examples/model_compress/quantization/mixed_precision_speedup_mnist.py create mode 100644 nni/compression/pytorch/quantization_speedup/__init__.py create mode 100644 nni/compression/pytorch/quantization_speedup/backend.py create mode 100644 nni/compression/pytorch/quantization_speedup/calibrator.py create mode 100644 nni/compression/pytorch/quantization_speedup/frontend_to_onnx.py create mode 100644 nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py create mode 100644 nni/compression/pytorch/quantization_speedup/trt_pycuda.py diff --git a/examples/model_compress/quantization/mixed_precision_speedup_mnist.py b/examples/model_compress/quantization/mixed_precision_speedup_mnist.py new file mode 100644 index 0000000000..bdcdcb7f5f --- /dev/null +++ b/examples/model_compress/quantization/mixed_precision_speedup_mnist.py @@ -0,0 +1,169 @@ +import torch +import torch.nn.functional as F +from torchvision import datasets, transforms + +from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer +from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT + +class Mnist(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 20, 5, 1) + self.conv2 = torch.nn.Conv2d(20, 50, 5, 1) + self.fc1 = torch.nn.Linear(4 * 4 * 50, 500) + self.fc2 = torch.nn.Linear(500, 10) + self.relu1 = torch.nn.ReLU6() + self.relu2 = torch.nn.ReLU6() + self.relu3 = torch.nn.ReLU6() + self.max_pool1 = torch.nn.MaxPool2d(2, 2) + self.max_pool2 = torch.nn.MaxPool2d(2, 2) + + def forward(self, x): + x = self.relu1(self.conv1(x)) + x = self.max_pool1(x) + x = self.relu2(self.conv2(x)) + x = self.max_pool2(x) + x = x.view(-1, 4 * 4 * 50) + x = self.relu3(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + +def train(model, device, train_loader, optimizer): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % 100 == 0: + print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item())) + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + test_loss /= len(test_loader.dataset) + + print('Loss: {} Accuracy: {}%)\n'.format( + test_loss, 100 * correct / len(test_loader.dataset))) + +def 
test_trt(engine, test_loader): + test_loss = 0 + correct = 0 + time_elasped = 0 + for data, target in test_loader: + output, time = engine.inference(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + time_elasped += time + test_loss /= len(test_loader.dataset) + + print('Loss: {} Accuracy: {}%'.format( + test_loss, 100 * correct / len(test_loader.dataset))) + print("Inference elapsed_time (whole dataset): {}s".format(time_elasped)) + +def post_training_quantization_example(train_loader, test_loader, device): + model = Mnist() + + config = { + 'conv1':{'weight_bit':8, 'activation_bit':8}, + 'conv2':{'weight_bit':32, 'activation_bit':32}, + 'fc1':{'weight_bit':16, 'activation_bit':16}, + 'fc2':{'weight_bit':8, 'activation_bit':8} + } + + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) + + model.to(device) + for epoch in range(1): + print('# Epoch {} #'.format(epoch)) + train(model, device, train_loader, optimizer) + test(model, device, test_loader) + + batch_size = 32 + input_shape = (batch_size, 1, 28, 28) + + engine = ModelSpeedupTensorRT(model, input_shape, config=config, calib_data_loader=train_loader, batchsize=batch_size) + engine.compress() + test_trt(engine, test_loader) + +def quantization_aware_training_example(train_loader, test_loader, device): + model = Mnist() + + configure_list = [{ + 'quant_types': ['weight', 'output'], + 'quant_bits': {'weight':8, 'output':8}, + 'op_names': ['conv1'] + }, { + 'quant_types': ['output'], + 'quant_bits': {'output':8}, + 'op_names': ['relu1'] + }, { + 'quant_types': ['weight', 'output'], + 'quant_bits': {'weight':8, 'output':8}, + 'op_names': ['conv2'] + }, { + 'quant_types': ['output'], + 'quant_bits': {'output':8}, + 'op_names': ['relu2'] + } + ] + + # finetune the model by using QAT + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) + quantizer = QAT_Quantizer(model, configure_list, optimizer) + quantizer.compress() + + model.to(device) + for epoch in range(1): + print('# Epoch {} #'.format(epoch)) + train(model, device, train_loader, optimizer) + test(model, device, test_loader) + + model_path = "mnist_model.pth" + calibration_path = "mnist_calibration.pth" + calibration_config = quantizer.export_model(model_path, calibration_path) + + test(model, device, test_loader) + + print("calibration_config: ", calibration_config) + + batch_size = 32 + input_shape = (batch_size, 1, 28, 28) + + engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size) + engine.compress() + + test_trt(engine, test_loader) + +def main(): + torch.manual_seed(0) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) + train_loader = torch.utils.data.DataLoader( + datasets.MNIST('data', train=True, download=True, transform=trans), + batch_size=64, shuffle=True) + test_loader = torch.utils.data.DataLoader( + datasets.MNIST('data', train=False, transform=trans), + batch_size=1000, shuffle=True) + + # post-training quantization on TensorRT + post_training_quantization_example(train_loader, test_loader, device) + + # combine NNI quantization algorithm QAT with backend framework TensorRT + quantization_aware_training_example(train_loader, test_loader, device) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git 
a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 18345fac32..ca40e30e45 100644 --- a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -152,22 +152,23 @@ def __init__(self, model, config_list, optimizer=None): for layer, config in modules_to_compress: layer.module.register_buffer("zero_point", torch.Tensor([0.0])) layer.module.register_buffer("scale", torch.Tensor([1.0])) + layer.module.register_buffer('ema_decay', torch.Tensor([0.99])) if "weight" in config.get("quant_types", []): layer.module.register_buffer('weight_bit', torch.zeros(1)) + layer.module.register_buffer('tracked_min_input', torch.zeros(1)) + layer.module.register_buffer('tracked_max_input', torch.zeros(1)) if "output" in config.get("quant_types", []): layer.module.register_buffer('activation_bit', torch.zeros(1)) - layer.module.register_buffer('ema_decay', torch.Tensor([0.99])) - layer.module.register_buffer('tracked_min_biased', torch.zeros(1)) - layer.module.register_buffer('tracked_min', torch.zeros(1)) - layer.module.register_buffer('tracked_max_biased', torch.zeros(1)) - layer.module.register_buffer('tracked_max', torch.zeros(1)) + layer.module.register_buffer('tracked_min_activation', torch.zeros(1)) + layer.module.register_buffer('tracked_max_activation', torch.zeros(1)) + def _del_simulated_attr(self, module): """ delete redundant parameters in quantize module """ - del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_biased', 'tracked_max_biased', 'tracked_min', \ - 'tracked_max', 'scale', 'zero_point', 'weight_bit', 'activation_bit'] + del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', \ + 'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit'] for attr in del_attr_list: if hasattr(module, attr): delattr(module, attr) @@ -243,15 +244,26 @@ def _dequantize(self, op, quantized_val): def quantize_weight(self, wrapper, **kwargs): config = wrapper.config module = wrapper.module + input = kwargs['input_tensor'] weight = copy.deepcopy(wrapper.module.old_weight.data) weight_bits = get_bits_length(config, 'weight') quant_start_step = config.get('quant_start_step', 0) assert weight_bits >= 1, "quant bits length should be at least 1" # we dont update weight in evaluation stage - if quant_start_step > self.bound_model.steps or not wrapper.training: + if quant_start_step > self.bound_model.steps: + module.tracked_min_input, module.tracked_max_input = torch.min(input), torch.max(input) + return weight + + if not wrapper.training: return weight + current_min, current_max = torch.min(input), torch.max(input) + module.tracked_min_input = update_ema(module.tracked_min_input, current_min, + module.ema_decay) + module.tracked_max_input = update_ema(module.tracked_max_input, current_max, + module.ema_decay) + # if bias exists, quantize bias to uint32 if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None: bias = wrapper.module.bias.data @@ -281,17 +293,17 @@ def quantize_output(self, output, wrapper, **kwargs): assert output_bits >= 1, "quant bits length should be at least 1" if quant_start_step > self.bound_model.steps: - module.tracked_min_biased, module.tracked_max_biased = torch.min(output), torch.max(output) + module.tracked_min_activation, module.tracked_max_activation = torch.min(output), torch.max(output) return output # we dont update output quantization 
parameters in evaluation stage if wrapper.training: current_min, current_max = torch.min(output), torch.max(output) - module.tracked_min_biased = update_ema(module.tracked_min_biased, current_min, + module.tracked_min_activation = update_ema(module.tracked_min_activation, current_min, module.ema_decay) - module.tracked_max_biased = update_ema(module.tracked_max_biased, current_max, + module.tracked_max_activation = update_ema(module.tracked_max_activation, current_max, module.ema_decay) - module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_biased, module.tracked_max_biased) + module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_activation, module.tracked_max_activation) out = self._quantize(output_bits, module, output) out = self._dequantize(module, out) return out @@ -327,10 +339,12 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_ calibration_config[name] = {} if hasattr(module, 'weight_bit'): calibration_config[name]['weight_bit'] = int(module.weight_bit) + calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input) + calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input) if hasattr(module, 'activation_bit'): calibration_config[name]['activation_bit'] = int(module.activation_bit) - calibration_config[name]['tracked_min'] = float(module.tracked_min_biased) - calibration_config[name]['tracked_max'] = float(module.tracked_max_biased) + calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation) + calibration_config[name]['tracked_max_activation'] = float(module.tracked_max_activation) self._del_simulated_attr(module) self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device) diff --git a/nni/compression/pytorch/compressor.py b/nni/compression/pytorch/compressor.py index 66a1d26220..7fecdc3b4f 100644 --- a/nni/compression/pytorch/compressor.py +++ b/nni/compression/pytorch/compressor.py @@ -483,7 +483,7 @@ def forward(self, *inputs): self.quantizer.quant_grad.apply( self.module.old_weight, QuantType.QUANT_WEIGHT, - self) + self, inputs[0]) result = self.module(*inputs) else: result = self.module(*inputs) @@ -511,14 +511,12 @@ def __init__(self, model, config_list, optimizer=None): # and it is trainable, therefore, it should be added to optimizer. self.optimizer.add_param_group({"params": wrapper.module.old_weight}) - def quantize_weight(self, weight, wrapper, **kwargs): + def quantize_weight(self, wrapper, **kwargs): """ quantize should overload this method to quantize weight. This method is effectively hooked to :meth:`forward` of the model. 
Parameters ---------- - weight : Tensor - weight that needs to be quantized wrapper : QuantizerModuleWrapper the wrapper for origin module """ @@ -720,11 +718,11 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qma return grad_output @staticmethod - def forward(ctx, tensor, quant_type, wrapper, **kwargs): + def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs): if quant_type == QuantType.QUANT_INPUT: output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs) elif quant_type == QuantType.QUANT_WEIGHT: - output = wrapper.quantizer.quantize_weight(wrapper, **kwargs) + output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs) elif quant_type == QuantType.QUANT_OUTPUT: output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs) else: diff --git a/nni/compression/pytorch/quantization_speedup/__init__.py b/nni/compression/pytorch/quantization_speedup/__init__.py new file mode 100644 index 0000000000..636c82a5b0 --- /dev/null +++ b/nni/compression/pytorch/quantization_speedup/__init__.py @@ -0,0 +1 @@ +from .integrated_tensorrt import CalibrateType, ModelSpeedupTensorRT \ No newline at end of file diff --git a/nni/compression/pytorch/quantization_speedup/backend.py b/nni/compression/pytorch/quantization_speedup/backend.py new file mode 100644 index 0000000000..7d139d48f8 --- /dev/null +++ b/nni/compression/pytorch/quantization_speedup/backend.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +class BaseModelSpeedup: + """ + Base speedup class for backend engine + """ + def __init__(self, model, config): + """ + Parameters + ---------- + model : pytorch model + The model to speed up by quantization. + config : dict + Config recording bit number and name of layers. + """ + self.model = model + self.config = config + + def inference(self, test_data): + """ + This function should be overrided by subclass to provide inference ability, + which should return output and inference time. + + Parameters + ---------- + test_data : numpy data + test data given to the inference engine + + Returns + ------- + numpy data + output data will be generated after inference + float + latency of such inference process + """ + raise NotImplementedError('Backend engine must overload inference()') + + def compress(self): + """ + This function should be overrided by subclass to build inference + engine which will be used to process input data + """ + raise NotImplementedError('Backend engine must overload compress()') + + def export_quantized_model(self, path): + """ + This function should be overrided by subclass to build inference + engine which will be used to process input data + """ + raise NotImplementedError('Backend engine must overload export_quantized_model()') \ No newline at end of file diff --git a/nni/compression/pytorch/quantization_speedup/calibrator.py b/nni/compression/pytorch/quantization_speedup/calibrator.py new file mode 100644 index 0000000000..6bc49622f2 --- /dev/null +++ b/nni/compression/pytorch/quantization_speedup/calibrator.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import os +import logging +import tensorrt as trt +import pycuda.driver as cuda + +logger = logging.getLogger(__name__) + +class Calibrator(trt.IInt8Calibrator): + def __init__(self, training_data, cache_file, batch_size=64, algorithm=trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2): + """ + Parameters + ---------- + training_data : numpy array + The data using to calibrate quantization model + cache_file : str + The path user want to store calibrate cache file + batch_size : int + The batch_size of calibrating process + algorithm : tensorrt.tensorrt.CalibrationAlgoType + The algorithms of calibrating contains LEGACY_CALIBRATION, + ENTROPY_CALIBRATION, ENTROPY_CALIBRATION_2, MINMAX_CALIBRATION. + Please refer to https://docs.nvidia.com/deeplearning/tensorrt/api/ + python_api/infer/Int8/Calibrator.html for detail + """ + trt.IInt8Calibrator.__init__(self) + + self.algorithm = algorithm + self.cache_file = cache_file + + self.data = training_data + self.batch_size = batch_size + self.current_index = 0 + + # Allocate enough memory for a whole batch. + self.device_input = cuda.mem_alloc(self.data[0].nbytes * self.batch_size) + + def get_algorithm(self): + return self.algorithm + + def get_batch_size(self): + return self.batch_size + + def get_batch(self, names): + """ + This function is used to define the way of feeding calibrating data each batch. + + Parameters + ---------- + names : str + The names of the network inputs for each object in the bindings array + + Returns + ------- + list + A list of device memory pointers set to the memory containing each network + input data, or an empty list if there are no more batches for calibration. + You can allocate these device buffers with pycuda, for example, and then + cast them to int to retrieve the pointer + """ + if self.current_index + self.batch_size > self.data.shape[0]: + return None + + current_batch = int(self.current_index / self.batch_size) + if current_batch % 10 == 0: + logger.info("Calibrating batch %d, containing %d images", current_batch, self.batch_size) + + batch = self.data[self.current_index:self.current_index + self.batch_size].ravel() + cuda.memcpy_htod(self.device_input, batch) + self.current_index += self.batch_size + memory_pointers = [self.device_input] + return memory_pointers + + def read_calibration_cache(self): + """ + If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None. + + Returns + ------- + cache object + A cache object which contains calibration parameters for quantization + """ + if os.path.exists(self.cache_file): + with open(self.cache_file, "rb") as f: + return f.read() + + def write_calibration_cache(self, cache): + """ + Write calibration cache to specific path. + + Parameters + ---------- + cache : str + The calibration cache to write + """ + with open(self.cache_file, "wb") as f: + f.write(cache) \ No newline at end of file diff --git a/nni/compression/pytorch/quantization_speedup/frontend_to_onnx.py b/nni/compression/pytorch/quantization_speedup/frontend_to_onnx.py new file mode 100644 index 0000000000..2bbb9f17e1 --- /dev/null +++ b/nni/compression/pytorch/quantization_speedup/frontend_to_onnx.py @@ -0,0 +1,148 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import torch +import onnx +import onnx.numpy_helper +""" +The main function of this page is to convert pytorch model to onnx model. 
+The conversion from a pytorch model to an onnx model is the primary step, but it
+raises a critical problem: layer names of the pytorch model cannot be mapped to
+onnx layer names directly. To solve this, we wrap each supported layer in a
+wrapper module that multiplies the layer's input by a per-layer index constant
+before the layer's computation. The resulting constant node carries the index
+through the export, so the onnx model can recover the bit configuration of the
+corresponding layer.
+"""
+
+class LayernameModuleWrapper(torch.nn.Module):
+    def __init__(self, module, module_bit) -> None:
+        """
+        Parameters
+        ----------
+        module : torch.nn.Module
+            Layer module of pytorch model
+        module_bit : int
+            Index of the module in the config, used to recover its bit setting after onnx export
+        """
+        super().__init__()
+        self.module = module
+        self.module_bit = module_bit
+
+    def forward(self, inputs):
+        inputs = inputs * self.module_bit
+        inputs = self.module(inputs)
+        return inputs
+
+def _setattr(model, name, module):
+    """
+    Parameters
+    ----------
+    model : pytorch model
+        The model to speed up by quantization
+    name : str
+        name of pytorch module
+    module : torch.nn.Module
+        Layer module of pytorch model
+    """
+    name_list = name.split(".")
+    for name in name_list[:-1]:
+        model = getattr(model, name)
+    setattr(model, name_list[-1], module)
+
+def unwrapper(model_onnx, index2name, config):
+    """
+    Fill the onnx config and remove the wrapper nodes from the onnx graph.
+
+    Parameters
+    ----------
+    model_onnx : onnx model
+        Onnx model which is converted from pytorch model
+    index2name : dict
+        Dictionary of layer index and name
+    config : dict
+        Config recording name of layers and calibration parameters
+
+    Returns
+    -------
+    onnx model
+        Onnx model which is converted from pytorch model
+    dict
+        The configuration of onnx model layers and calibration parameters
+    """
+    # Supports Gemm, Conv, Relu, Clip (Relu6) and MaxPool; matched against
+    # the first four characters of the onnx node name.
+    support_op = ['Gemm', 'Conv', 'Relu', 'Clip', 'MaxP']
+    idx = 0
+    onnx_config = {}
+    while idx < len(model_onnx.graph.node):
+        nd = model_onnx.graph.node[idx]
+        if nd.name[0:4] in support_op and idx > 1:
+            # Get the constant node and multiply node inserted by the wrapper
+            const_nd = model_onnx.graph.node[idx-2]
+            mul_nd = model_onnx.graph.node[idx-1]
+            # Get the layer index which is carried by the constant node
+            index = int(onnx.numpy_helper.to_array(const_nd.attribute[0].t))
+            if index != -1:
+                name = index2name[index]
+                onnx_config[nd.name] = config[name]
+            nd.input[0] = mul_nd.input[0]
+            # Remove the constant node and multiply node
+            model_onnx.graph.node.remove(const_nd)
+            model_onnx.graph.node.remove(mul_nd)
+            idx = idx - 2
+        idx = idx + 1
+    return model_onnx, onnx_config
+
+def torch_to_onnx(model, config, input_shape, model_path, input_names, output_names):
+    """
+    Convert torch model to onnx model and get layer bit config of onnx model.
+ + Parameters + ---------- + model : pytorch model + The model to speed up by quantization + config : dict + Config recording bit number and name of layers + input_shape : tuple + The input shape of model, shall pass it to torch.onnx.export + model_path : str + The path user want to store onnx model which is converted from pytorch model + input_names : list + Input name of onnx model providing for torch.onnx.export to generate onnx model + output_name : list + Output name of onnx model providing for torch.onnx.export to generate onnx model + + Returns + ------- + onnx model + Onnx model which is converted from pytorch model + dict + The configuration of onnx model layers and calibration parameters + """ + # Support Gemm, Conv, Relu, Clip(Relu6) and MaxPool + support_op = [torch.nn.Conv2d, torch.nn.Linear, torch.nn.ReLU, torch.nn.ReLU6, torch.nn.MaxPool2d] + # Transfer bit number to onnx layer by using wrapper + index2name = {} + name2index = {} + if config is not None: + for i, name in enumerate(config.keys()): + index2name[i] = name + name2index[name] = i + for name, module in model.named_modules(): + if config is not None and name in config: + assert type(module) in support_op + wrapper_module = LayernameModuleWrapper(module, name2index[name]) + _setattr(model, name, wrapper_module) + elif type(module) in support_op: + wrapper_module = LayernameModuleWrapper(module, -1) + _setattr(model, name, wrapper_module) + # Convert torch model to onnx model and save it in model_path + dummy_input = torch.randn(input_shape) + model.to('cpu') + torch.onnx.export(model, dummy_input, model_path, verbose=False, input_names=input_names, output_names=output_names, export_params=True) + + # Load onnx model + model_onnx = onnx.load(model_path) + model_onnx, onnx_config = unwrapper(model_onnx, index2name, config) + onnx.save(model_onnx, model_path) + + onnx.checker.check_model(model_onnx) + return model_onnx, onnx_config \ No newline at end of file diff --git a/nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py b/nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py new file mode 100644 index 0000000000..c7849774cc --- /dev/null +++ b/nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py @@ -0,0 +1,381 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import time +import logging +import tensorrt as trt +import numpy as np +import torch + +from . import frontend_to_onnx as fonnx +from . import calibrator as calibrator +from . 
import trt_pycuda as common
+from .backend import BaseModelSpeedup
+
+# TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
+TRT_LOGGER = trt.Logger()
+logger = logging.getLogger(__name__)
+
+class CalibrateType:
+    LEGACY = trt.CalibrationAlgoType.LEGACY_CALIBRATION
+    ENTROPY = trt.CalibrationAlgoType.ENTROPY_CALIBRATION
+    ENTROPY2 = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2
+    MINMAX = trt.CalibrationAlgoType.MINMAX_CALIBRATION
+
+Precision_Dict = {
+    8: trt.int8,
+    16: trt.float16,
+    32: trt.float32
+}
+
+def valid_config(config=None):
+    """
+    This function validates the bit setting configuration.
+    """
+    if config is None:
+        return
+    support_bit = [8, 16, 32]
+    for name in config.keys():
+        if 'weight_bit' in config[name]:
+            w_bit = config[name]['weight_bit']
+            assert w_bit in support_bit, "weight bit should be 8, 16, 32"
+        if 'activation_bit' in config[name]:
+            a_bit = config[name]['activation_bit']
+            assert a_bit in support_bit, "activation bit should be 8, 16, 32"
+
+def handle_gemm(network, layer_idx, config):
+    """
+    This function gives Gemm operations special handling, because the number of layers
+    corresponding to a Gemm changes during the pytorch->onnx model conversion.
+
+    Parameters
+    ----------
+    network : tensorrt.INetworkDefinition
+        Represents a TensorRT Network from which the Builder can build an Engine
+    layer_idx : int
+        layer index of gemm
+    config : dict
+        Config recording bit number and name of layers
+    """
+    layer = network.get_layer(layer_idx)
+    pre_layer = network.get_layer(layer_idx-1)
+    next_layer = network.get_layer(layer_idx+1)
+    # if weight bit exists, set three layers' precision,
+    # input tensor range and the first two layers' output type
+    if 'weight_bit' in config[layer.name]:
+        assert 'tracked_min_input' in config[layer.name]
+        assert 'tracked_max_input' in config[layer.name]
+        w_bit = config[layer.name]['weight_bit']
+        tracked_min_input = config[layer.name]['tracked_min_input']
+        tracked_max_input = config[layer.name]['tracked_max_input']
+        # set three layers the same precision
+        layer.precision = Precision_Dict[w_bit]
+        pre_layer.precision = Precision_Dict[w_bit]
+        next_layer.precision = Precision_Dict[w_bit]
+        # set the first two layers' output type
+        pre_layer.set_output_type(0, Precision_Dict[w_bit])
+        layer.set_output_type(0, Precision_Dict[w_bit])
+        pre_in_tensor = pre_layer.get_input(0)
+        in_tensor = layer.get_input(0)
+        next_in_tensor = next_layer.get_input(0)
+        # set three layers' input tensor range
+        pre_in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)
+        in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)
+        next_in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)
+
+    # if activation bit exists, set the last layer's output type and output tensor range
+    if 'activation_bit' in config[layer.name]:
+        assert 'tracked_min_activation' in config[layer.name]
+        assert 'tracked_max_activation' in config[layer.name]
+        a_bit = config[layer.name]['activation_bit']
+        tracked_min_activation = config[layer.name]['tracked_min_activation']
+        tracked_max_activation = config[layer.name]['tracked_max_activation']
+        # set the last layer's output type
+        next_layer.set_output_type(0, Precision_Dict[a_bit])
+        next_out_tensor = next_layer.get_output(0)
+        # set the last layer's output tensor range
+        next_out_tensor.dynamic_range = (tracked_min_activation, tracked_max_activation)
+
+def build_engine(model_file, config=None, extra_layer_bit=32, strict_datatype=False, calib=None):
+    """
+    This function builds an engine from an onnx model with calibration process.
+ + Parameters + ---------- + model_file : str + The path of onnx model + config : dict + Config recording bit number and name of layers + extra_layer_bit : int + Other layers which are not in config will be quantized to corresponding bit number + strict_datatype : bool + Whether constrain layer bit to the number given in config or not. If true, all the layer + will be set to given bit strictly. Otherwise, these layers will be set automatically by + tensorrt + calib : numpy array + The data using to calibrate quantization model + + Returns + ------- + tensorrt.ICudaEngine + An ICudaEngine for executing inference on a built network + """ + with trt.Builder(TRT_LOGGER) as builder, builder.create_network(common.EXPLICIT_BATCH) as network, \ + trt.OnnxParser(network, TRT_LOGGER) as parser: + # Attention that, builder should be set to 1 because of the implementation of allocate_buffer + builder.max_batch_size = 1 + builder.max_workspace_size = common.GiB(4) + + if extra_layer_bit == 32 and config is None: + pass + elif extra_layer_bit == 16 and config is None: + builder.fp16_mode = True + elif extra_layer_bit == 8 and config is None: + # entire model in 8bit mode + builder.int8_mode = True + else: + builder.int8_mode = True + builder.fp16_mode = True + builder.strict_type_constraints = strict_datatype + + valid_config(config) + + # Parse onnx model + with open(model_file, 'rb') as model: + if not parser.parse(model.read()): + logger.error('ERROR: Fail to parse the ONNX file.') + for error in range(parser.num_errors): + logger.error(parser.get_error(error)) + return None + + if calib is not None: + builder.int8_calibrator = calib + # This design may not be correct if output more than one + for i in range(network.num_layers): + if config is None: + break + layer = network.get_layer(i) + if layer.name in config: + w_bit = config[layer.name]['weight_bit'] + a_bit = config[layer.name]['activation_bit'] + layer.precision = Precision_Dict[w_bit] + layer.set_output_type(0, Precision_Dict[a_bit]) + else: + # This implementation may be incorrect when output number > 1 + for i in range(network.num_layers): + if config is None: + # no low bit layer need to be set, keep original model + break + layer = network.get_layer(i) + if layer.name not in config: + continue + # layer numbers of gemm changed during pytorch->onnx model convertion, need special handle + if layer.name[0:4] == "Gemm": + handle_gemm(network, i, config) + continue + + # If weight_bit exists in config, set layer precision and layer's input tensor dynamic range. + if 'weight_bit' in config[layer.name]: + assert 'tracked_min_input' in config[layer.name] + assert 'tracked_max_input' in config[layer.name] + w_bit = config[layer.name]['weight_bit'] + tracked_min_input = config[layer.name]['tracked_min_input'] + tracked_max_input = config[layer.name]['tracked_max_input'] + layer.precision = Precision_Dict[w_bit] + in_tensor = layer.get_input(0) + in_tensor.dynamic_range = (tracked_min_input, tracked_max_input) + + # If activation exists in config, set layer output type and layer's output tensor dynamic range. 
+ if 'activation_bit' in config[layer.name]: + assert 'tracked_min_activation' in config[layer.name] + assert 'tracked_max_activation' in config[layer.name] + a_bit = config[layer.name]['activation_bit'] + tracked_min_activation = config[layer.name]['tracked_min_activation'] + tracked_max_activation = config[layer.name]['tracked_max_activation'] + layer.set_output_type(0, Precision_Dict[a_bit]) + out_tensor = layer.get_output(0) + out_tensor.dynamic_range = (tracked_min_activation, tracked_max_activation) + + # Build engine and do int8 calibration. + engine = builder.build_cuda_engine(network) + return engine + +class ModelSpeedupTensorRT(BaseModelSpeedup): + def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bit=32, strict_datatype=True, + calibrate_type=CalibrateType.ENTROPY2, calib_data_loader=None, calibration_cache = "calibration.cache", batchsize=1, + input_names=["actual_input_1"], output_names=["output1"]): + """ + Parameters + ---------- + model : pytorch model + The model to speed up by quantization. + input_shape : tuple + The input shape of model, shall pass it to torch.onnx.export. + config : dict + Config recording bit number and name of layers. + onnx_path : str + The path user want to store onnx model which is converted from pytorch model. + extra_layer_bit : int + Other layers which are not in config will be quantized to corresponding bit number. + strict_datatype : bool + Whether constrain layer bit to the number given in config or not. If true, all the layer + will be set to given bit strictly. Otherwise, these layers will be set automatically by + tensorrt. + calibrate_type : tensorrt.tensorrt.CalibrationAlgoType + The algorithm of calibrating. Please refer to https://docs.nvidia.com/deeplearning/ + tensorrt/api/python_api/infer/Int8/Calibrator.html for detail + calibrate_data : numpy array + The data using to calibrate quantization model + calibration_cache : str + The path user want to store calibrate cache file + batchsize : int + The batch size of calibration and inference + input_names : list + Input name of onnx model providing for torch.onnx.export to generate onnx model + output_name : list + Output name of onnx model providing for torch.onnx.export to generate onnx model + """ + super().__init__(model, config) + self.model = model + self.onnx_path = onnx_path + self.input_shape = input_shape + self.config = config + self.extra_layer_bit = extra_layer_bit + self.strict_datatype = strict_datatype + self.calibrate_type = calibrate_type + self.calib_data_loader = calib_data_loader + self.calibration_cache = calibration_cache + self.batchsize = batchsize + self.input_names = input_names + self.output_names = output_names + self.context = None + self.onnx_config = {} + + def compress(self): + """ + Get onnx config and build tensorrt engine. 
+ """ + assert self.model is not None + assert self.onnx_path is not None + assert self.input_shape is not None + + # Convert pytorch model to onnx model and save onnx model in onnx_path + _, self.onnx_config = fonnx.torch_to_onnx(self.model, self.config, input_shape=self.input_shape, + model_path=self.onnx_path, input_names=self.input_names, output_names=self.output_names) + + if self.calib_data_loader is not None: + assert self.calibrate_type is not None + context = self._tensorrt_build_withcalib(self.onnx_path) + else: + context = self._tensorrt_build_withoutcalib(self.onnx_path) + self.context = context + + def _tensorrt_build_withcalib(self, onnx_path): + """ + Convert pytorch tensor to numpy darray + + Parameters + ---------- + onnx_path : str + The path of onnx model + + Returns + ------- + tensorrt.IExecutionContext + Context for executing inference using an ICudaEngine + """ + calib_data = None + if type(self.calib_data_loader) == torch.utils.data.dataloader.DataLoader: + calib_data_set = [] + for data, _ in self.calib_data_loader: + calib_data_set.append(data) + calib_data = np.concatenate(calib_data_set) + elif type(self.calib_data_loader) == torch.Tensor: + calib_data = self.calib_data_loader.numpy() + else: + raise ValueError("Not support calibration datatype") + calib = calibrator.Calibrator(calib_data, self.calibration_cache, self.batchsize, self.calibrate_type) + + # build inference engine with calibration + engine = build_engine(onnx_path, self.onnx_config, self.extra_layer_bit, self.strict_datatype, calib) + return engine.create_execution_context() + + def _tensorrt_build_withoutcalib(self, onnx_path): + """ + Build inference engine without calibration + + Parameters + ---------- + onnx_path : str + The path of onnx model + + Returns + ------- + tensorrt.IExecutionContext + Context for executing inference using an ICudaEngine + """ + engine = build_engine(onnx_path, self.onnx_config, self.extra_layer_bit, self.strict_datatype) + return engine.create_execution_context() + + def inference(self, test_data): + """ + Do inference by tensorrt builded engine. + + Parameters + ---------- + test_data : pytorch tensor + Model input tensor + """ + # convert pytorch tensor to numpy darray + test_data = test_data.numpy() + # Numpy dtype should be float32 + assert test_data.dtype == np.float32 + elapsed_time = 0 + inputs, outputs, bindings, stream = common.allocate_buffers(self.context.engine) + result = [] + for start_idx in range(0, test_data.shape[0], self.batchsize): + # If the number of images in the test set is not divisible by the batch size, the last batch will be smaller. + # This logic is used for handling that case. + end_idx = min(start_idx + self.batchsize, test_data.shape[0]) + effective_batch_size = end_idx - start_idx + + # Do inference for every batch. + inputs[0].host = test_data[start_idx:start_idx + effective_batch_size] + t1 = time.time() + [output] = common.do_inference_v2(self.context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream) + elapsed_time += time.time() - t1 + shape = output.shape[0] + output = output[0:int(shape * effective_batch_size / self.batchsize)].reshape(effective_batch_size, -1) + result.append(output.copy()) + # Use argmax to get predictions and then check accuracy + # convert numpy darray to pytorch tensor + result = torch.Tensor(np.concatenate(result)) + return result, elapsed_time + + def export_quantized_model(self, path): + """ + Export TensorRT quantized model engine which only can be loaded by TensorRT deserialize API. 
+ + Parameters + ---------- + path : str + The path of export model + """ + assert path is not None + with open(path, "wb") as f: + f.write(self.context.engine.serialize()) + logger.info("TensorRT engine has been saved to %s", path) + + def load_quantized_model(self, path): + """ + Load TensorRT quantized model engine from specific path. + + Parameters + ---------- + path : str + The path of export model + """ + assert path is not None + with open(path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime: + engine = runtime.deserialize_cuda_engine(f.read()) + self.context = engine.create_execution_context() + logger.info("Load TensorRT engine from %s successfully.", path) \ No newline at end of file diff --git a/nni/compression/pytorch/quantization_speedup/trt_pycuda.py b/nni/compression/pytorch/quantization_speedup/trt_pycuda.py new file mode 100644 index 0000000000..d3f8e1f4c6 --- /dev/null +++ b/nni/compression/pytorch/quantization_speedup/trt_pycuda.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pycuda.driver as cuda +import pycuda.autoinit # pylint: disable=unused-import +import tensorrt as trt + +EXPLICIT_BATCH = 1 + +def GiB(val): + return val * 1 << 30 + +# Simple helper data class that's a little nicer to use than a 2-tuple. +class HostDeviceMem(object): + def __init__(self, host_mem, device_mem): + """ + This function builds an engine from an onnx model with calibration process. + + Parameters + ---------- + host_mem : host memory + Memory buffers of host + device_mem : device memory + Memory buffers of device + """ + self.host = host_mem + self.device = device_mem + + def __str__(self): + return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device) + + def __repr__(self): + return self.__str__() + +def allocate_buffers(engine): + """ + Allocates all buffers required for an engine, i.e. host/device inputs/outputs. + + Parameters + ---------- + engine : tensorrt.ICudaEngine + An ICudaEngine for executing inference on a built network + + Returns + ------- + list + All input HostDeviceMem of an engine + list + All output HostDeviceMem of an engine + GPU bindings + Device bindings + GPU stream + A stream is a sequence of commands (possibly issued by different host threads) that execute in order + """ + inputs = [] + outputs = [] + bindings = [] + stream = cuda.Stream() + for binding in engine: + size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size + dtype = trt.nptype(engine.get_binding_dtype(binding)) + # Allocate host and device buffers + host_mem = cuda.pagelocked_empty(size, dtype) + device_mem = cuda.mem_alloc(host_mem.nbytes) + # Append the device buffer to device bindings. + bindings.append(int(device_mem)) + # Append to the appropriate list. + if engine.binding_is_input(binding): + inputs.append(HostDeviceMem(host_mem, device_mem)) + else: + outputs.append(HostDeviceMem(host_mem, device_mem)) + return inputs, outputs, bindings, stream + +# This function is generalized for multiple inputs/outputs for full dimension networks. +# inputs and outputs are expected to be lists of HostDeviceMem objects. +def do_inference_v2(context, bindings, inputs, outputs, stream): + # Transfer input data to the GPU. + [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs] + # Run inference. + context.execute_async_v2(bindings=bindings, stream_handle=stream.handle) + # Transfer predictions back from the GPU. 
+ [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs] + # Synchronize the stream + stream.synchronize() + # Return only the host outputs. + return [out.host for out in outputs] \ No newline at end of file diff --git a/pylintrc b/pylintrc index ccefe24702..e2965c706d 100644 --- a/pylintrc +++ b/pylintrc @@ -45,6 +45,6 @@ enable= unused-wildcard-import, ignore-patterns=test* # List of members which are set dynamically and missed by pylint inference -generated-members=numpy.*,torch.*,tensorflow.* +generated-members=numpy.*,torch.*,tensorflow.*,pycuda.*,tensorrt.* -ignored-modules=tensorflow,_winapi,msvcrt +ignored-modules=tensorflow,_winapi,msvcrt,tensorrt,pycuda diff --git a/test/ut/sdk/test_compressor_torch.py b/test/ut/sdk/test_compressor_torch.py index 11a1eb7e30..b7b0b2019e 100644 --- a/test/ut/sdk/test_compressor_torch.py +++ b/test/ut/sdk/test_compressor_torch.py @@ -239,15 +239,16 @@ def test_torch_QAT_quantizer(self): # test quantize # range not including 0 eps = 1e-7 + input = torch.tensor([[0, 4], [2, 1]]).float() weight = torch.tensor([[1, 2], [3, 5]]).float() model.conv2.module.old_weight.data = weight - quantizer.quantize_weight(model.conv2) + quantizer.quantize_weight(model.conv2, input_tensor=input) assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps) assert model.conv2.module.zero_point == 0 # range including 0 weight = torch.tensor([[-1, 2], [3, 5]]).float() model.conv2.module.old_weight.data = weight - quantizer.quantize_weight(model.conv2) + quantizer.quantize_weight(model.conv2, input_tensor=input) assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps) assert model.conv2.module.zero_point in (42, 43) # test value of weight and bias after quantization @@ -257,7 +258,7 @@ def test_torch_QAT_quantizer(self): bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341]) model.conv2.module.old_weight.data = weight model.conv2.module.bias.data = bias - quantizer.quantize_weight(model.conv2) + quantizer.quantize_weight(model.conv2, input_tensor=input) assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4)) assert torch.all(torch.isclose(model.conv2.module.bias.data, bias_valid, rtol=1e-7)) @@ -265,14 +266,14 @@ def test_torch_QAT_quantizer(self): eps = 1e-7 x = torch.tensor([[-0.2, 0], [0.1, 0.2]]) out = model.relu(x) - assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps) - assert math.isclose(model.relu.module.tracked_max_biased, 0.002, abs_tol=eps) + assert math.isclose(model.relu.module.tracked_min_activation, 0, abs_tol=eps) + assert math.isclose(model.relu.module.tracked_max_activation, 0.002, abs_tol=eps) quantizer.step_with_optimizer() x = torch.tensor([[0.2, 0.4], [0.6, 0.8]]) out = model.relu(x) - assert math.isclose(model.relu.module.tracked_min_biased, 0.002, abs_tol=eps) - assert math.isclose(model.relu.module.tracked_max_biased, 0.00998, abs_tol=eps) + assert math.isclose(model.relu.module.tracked_min_activation, 0.002, abs_tol=eps) + assert math.isclose(model.relu.module.tracked_max_activation, 0.00998, abs_tol=eps) def test_torch_quantizer_export(self): config_list_qat = [{
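Usage sketch (illustrative, not part of the patch): the engine built by ModelSpeedupTensorRT can be serialized and reloaded through the export_quantized_model() and load_quantized_model() methods added in integrated_tensorrt.py. The snippet below assumes `model` is the QAT-finetuned Mnist network and `calibration_config` is the dict returned by quantizer.export_model() in the example script above; the engine file name and input shape are placeholders.

# Assumes `model` and `calibration_config` come from the QAT example above.
import torch
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

batch_size = 32
input_shape = (batch_size, 1, 28, 28)

engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size)
engine.compress()                                    # build the TensorRT engine from the onnx export
engine.export_quantized_model("mnist_int8.engine")   # serialize the engine (illustrative path)

# Reload the serialized engine into a fresh speedup object, then run inference.
engine2 = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size)
engine2.load_quantized_model("mnist_int8.engine")
output, seconds = engine2.inference(torch.randn(input_shape))  # float32 CPU tensor in; (outputs, latency) out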