
Combine tensorrt tool with NNI quantization algorithms. #3488

Merged
21 commits merged into microsoft:master on Apr 9, 2021

Conversation

linbinskn
Contributor

NNI TensorRT support
Targets:
1. Support real quantized-model speedup in NNI on different hardware (currently only TensorRT is supported).
2. Support mixed-precision search, especially the interface design for mixed quantization.
3. Combine quantized inference with NNI's existing simulated quantization interface, mainly supporting QAT.
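For reference, a minimal sketch of the intended end-to-end flow, pieced together from the example snippets in this PR. The import paths, the QAT_Quantizer/export_model calls, and the file names are assumptions based on NNI's API at the time, not the final interface:

import torch
import torch.nn as nn
# Import paths below are assumptions based on NNI's layout around this PR.
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 26 * 26, 10))

# 1. Simulated quantization (QAT) to collect quantization parameters.
config_list = [{'quant_types': ['weight', 'output'],
                'quant_bits': {'weight': 8, 'output': 8},
                'op_types': ['Conv2d', 'Linear']}]
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
quantizer = QAT_Quantizer(model, config_list, optimizer)
quantizer.compress()
# ... finetune the wrapped model here so QAT can track dynamic ranges ...
calibration_config = quantizer.export_model('model.pth', 'calibration.pth')

# 2. Real speedup with TensorRT, reusing the exported calibration config
#    (constructor and calls as they appear in the example code of this PR).
engine = ModelSpeedupTensorRT(model, (32, 1, 28, 28), config=calibration_config, batchsize=32)
engine.compress()
# output, time = engine.inference(test_set)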



def resnet18(**kwargs):
    return _resnet(BasicBlock, [2, 2, 2, 2], **kwargs)
Contributor

Would it be better to create a models folder?

Contributor Author

Good point. I will create this model by importing it from a folder once the location of the model compression models folder in examples is confirmed.

engine.compress()
output, time = engine.inference(test_set)

check_accuracy(output, test_labels)
Contributor

Why not run inference on the full test dataset and compare the accuracy of the quantized model with that of the original model?

Contributor Author

The evaluation dataset is already the full test dataset. For the current scenario in this example, training a QAT model from scratch, we should compare the accuracy of the QAT-quantized model with the accuracy of the sped-up model, both of which are printed.

'layer4.1.conv1':{'weight_bit':8, 'activation_bit':8},
'layer4.1.conv2':{'weight_bit':8, 'activation_bit':8},
'fc':{'weight_bit':8, 'activation_bit':8},
}
Contributor

Can we specify one bit width for all layers?

Contributor Author

We haven't supported that yet, because not all op types are supported in quantization. However, we can specify one bit width for a specific supported op type, and all layers of that op type will be quantized.
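For reference, a hedged sketch of what specifying one bit width for a whole supported op type could look like with NNI's simulated-quantization config format (the field names follow NNI's config_list convention and are an assumption here, not code from this PR):

# Quantize every Conv2d layer with 8-bit weights and activations;
# op types that are not yet supported are simply left out of the config.
config_list = [{
    'quant_types': ['weight', 'output'],
    'quant_bits': {'weight': 8, 'output': 8},
    'op_types': ['Conv2d']
}]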

self.algorithm = algorithm
self.cache_file = cache_file

# Every time get_batch is called, the next batch of size batch_size will be copied to the device and returned.
Contributor

This comment looks strange here; moving it to get_batch would be better.

Contributor Author

This comment is for self.batch_size. I have modified it to make that clear.
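For context, TensorRT's Python API expects an INT8 calibrator to expose get_batch_size and get_batch. Below is a simplified, self-contained sketch of such a calibrator (not the PR's exact implementation; the class and attribute names are illustrative):

import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit  # noqa: F401, initializes the CUDA context
import tensorrt as trt

class SimpleCalibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, data, batch_size, cache_file='calib.cache'):
        super().__init__()
        self.data = data                    # numpy array of calibration samples
        self.batch_size = batch_size        # batch size reported to TensorRT
        self.cache_file = cache_file
        self.index = 0
        # device buffer large enough to hold one batch
        self.device_input = cuda.mem_alloc(self.data[0].nbytes * batch_size)

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):
        # Every time get_batch is called, the next batch of size batch_size
        # is copied to the device and its device pointer is returned.
        if self.index + self.batch_size > len(self.data):
            return None
        batch = np.ascontiguousarray(self.data[self.index:self.index + self.batch_size])
        cuda.memcpy_htod(self.device_input, batch)
        self.index += self.batch_size
        return [int(self.device_input)]

    def read_calibration_cache(self):
        return None                         # always recalibrate in this sketch

    def write_calibration_cache(self, cache):
        with open(self.cache_file, 'wb') as f:
            f.write(cache)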

return None

input_tensor = network.get_input(0)
input_tensor.dynamic_range = (-100, 100)
Contributor

Why is the input dynamic range set to (-100, 100)?

Contributor Author

The range (-100, 100) was just for testing. It has been deleted in the latest commit.

engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size)
engine.compress()

test_trt(engine, test_loader)
Contributor

It seems this example includes the above example, so why are two examples needed?

Contributor Author

In the example mixed_precision_speedup_mnist.py, the model is quantized in TensorRT directly by providing a calibration dataset, and TensorRT obtains the quantization parameters through the calibration process. We can consider this post-training quantization.
However, in the example mixed_precision_speedup_mnist_QAT.py, we first finetune the model and obtain the quantization parameters with the QAT algorithm. Then the model is quantized in TensorRT without a calibration dataset. We can consider this quantization-aware training.

Contributor

I suggest putting them into one example file.

Contributor Author

I have put them into one file.

@@ -31,7 +31,7 @@ def validate_config(self, model, config_list):

schema.validate(config_list)

def quantize_weight(self, wrapper, **kwargs):
def quantize_weight(self, input, wrapper, **kwargs):
Contributor

What is the meaning of input? And is it used in this function?

Contributor Author

In this new version, the input tensor's dynamic range is also recorded to meet the requirement of TensorRT's tensor range setting.

Contributor

So what is the meaning of input?

Contributor Author
@linbinskn linbinskn Apr 5, 2021

input is the input tensor of this op. It is used to calibrate the input tensor's dynamic range. It won't be used by all quantizers, so I have passed it via kwargs.
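A rough sketch of how quantize_weight could track the input tensor's range for the later TensorRT dynamic-range setting. The tracked_min_input/tracked_max_input names mirror the variables visible elsewhere in this diff; the rest of the body is an assumption, not the merged code:

import torch

def quantize_weight(self, input, wrapper, **kwargs):
    module = wrapper.module
    # Track the observed range of the op's input so the TensorRT backend can
    # later call in_tensor.dynamic_range = (tracked_min_input, tracked_max_input).
    batch_min = torch.min(input).item()
    batch_max = torch.max(input).item()
    module.tracked_min_input = min(getattr(module, 'tracked_min_input', batch_min), batch_min)
    module.tracked_max_input = max(getattr(module, 'tracked_max_input', batch_max), batch_max)
    # ... existing weight fake-quantization logic continues here ...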

@@ -720,11 +720,11 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qma
return grad_output

@staticmethod
def forward(ctx, tensor, quant_type, wrapper, **kwargs):
def forward(ctx, tensor, quant_type, wrapper, tensor_alt=None, **kwargs):
Contributor

What is the meaning of tensor_alt?

Contributor Author

It is used to pass a second tensor during the forward process. In the previous implementation, we only passed one tensor, such as weight, input, or output. But in some situations, we need to pass two tensors and calibrate both of them, as in the function quantize_weight. So I added it here.

Contributor

Is tensor_alt commonly used across different quantizers? If it is specific to some quantizers, I suggest putting it in kwargs; that is what kwargs is for.

Contributor Author

The argument tensor_alt may not be used by most quantizers, but I don't think it is a bad thing to provide an alternative here. What's more, forward in class QuantGrad is called through apply(), which only supports positional arguments, so kwargs may be empty here. If we forced it into kwargs, an error would be raised. For these reasons, I think it can be kept.
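A simplified sketch of the pattern being discussed (illustrative, not the PR's code): since apply() forwards positional arguments to forward, the optional second tensor sits more naturally as a defaulted positional parameter than inside kwargs.

import torch

class QuantGradSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, quant_type, wrapper, tensor_alt=None, **kwargs):
        # tensor_alt carries an optional second tensor (e.g. the op input when
        # quantizing weights) so both ranges can be calibrated in one pass.
        ctx.save_for_backward(tensor)
        return tensor  # identity here; a real quantizer would fake-quantize it

    @staticmethod
    def backward(ctx, grad_output):
        # one gradient per forward input: tensor, quant_type, wrapper, tensor_alt
        return grad_output, None, None, None

x = torch.randn(4, requires_grad=True)
out = QuantGradSketch.apply(x, 'weight', None, torch.randn(4))  # positional arguments only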

return self.__str__()

# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
def allocate_buffers(engine):
Contributor

It is a little strange to put these TensorRT-specific functions in common.py.

Contributor Author

All the functions in it are about CUDA memory operations. I think it is better to keep them separate to make the code easier to understand.

Contributor
@QuanluZhang QuanluZhang Apr 5, 2021

I agree, so you can rename this file, for example to "trt_cuda.py": we are supposed to support different backends, and CUDA memory operations are still specific to NVIDIA GPUs.

Contributor Author

I have renamed common.py to 'trt_pycuda.py'.

engine = build_engine(onnx_path, calib, self.onnx_config, self.extra_layer_bit, self.strict_datatype)
return engine.create_execution_context()

def tensorrt_build_withoutcalib(self, onnx_path):
Contributor

If a member function is not supposed to be exposed to users, it would be better to add _ before the function name.

Contributor Author

Agreed. I have added it.

@QuanluZhang
Contributor

@linbinskn please update the doc accordingly. Also prepare a unit test; we will set up the environment for this unit test.

linbinskn added a commit to linbinskn/nni that referenced this pull request Apr 5, 2021
@linbinskn
Contributor Author

@linbinskn please update the doc accordingly. Also prepare a unit test; we will set up the environment for this unit test.

I have updated the doc in #3512. The unit test is also prepared and will be pushed after the environment is set up.

engine = builder.build_cuda_engine(network)
return engine

def build_engine_without_calib(model_file, config=None, extra_layer_bit=32, strict_datatype=False):
Contributor

I suggest combining this function with "build_engine".

Contributor Author

I have combined them into one function.


    if extra_layer_bit == 32 and config is None:
        pass
    elif extra_layer_bit == 8 and config is None:
Contributor
@QuanluZhang QuanluZhang Apr 6, 2021

What if extra_layer_bit is 16 and config is None?

Contributor Author

We should turn on fp16 mode in that case. I have modified it.
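A condensed sketch of how the merged build_engine can pick the precision mode from extra_layer_bit and the presence of a calibrator, based on the legacy TensorRT builder flags visible in this diff (an outline, not the merged implementation):

def _set_precision(builder, extra_layer_bit, config, calib):
    """Configure builder precision for layers not covered by the per-layer config."""
    if config is None and extra_layer_bit == 32:
        pass                                 # pure FP32, nothing to set
    elif config is None and extra_layer_bit == 16:
        builder.fp16_mode = True             # whole model in FP16
    elif config is None and extra_layer_bit == 8:
        builder.int8_mode = True             # whole model in INT8
        builder.int8_calibrator = calib      # TensorRT calibration is INT8-only
    else:
        # mixed precision: per-layer bits come from config, the remaining
        # layers fall back to int8/fp16, with calibration when calib is given
        builder.int8_mode = True
        builder.fp16_mode = True
        if calib is not None:
            builder.int8_calibrator = calib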

    else:
        builder.int8_mode = True
        builder.fp16_mode = True
        builder.int8_calibrator = calib
Contributor

So calib is only for int8? What about 2 bits? 4 bits?

Contributor Author

Only int8 is supported by TensorRT, so 2-bit or 4-bit calibration is not possible; the int8_calibrator parameter is fixed in the TensorRT builder.

    # Parse onnx model
    with open(model_file, 'rb') as model:
        if not parser.parse(model.read()):
            print ('ERROR: Fail to parse the ONNX file.')
Contributor

print -> logging

Contributor Author

I have substituted it.
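For instance, the parse failure can be reported through the standard logging module instead of print (a small sketch; parser and model_file come from the surrounding build_engine code):

import logging

logger = logging.getLogger(__name__)

with open(model_file, 'rb') as model:
    if not parser.parse(model.read()):
        logger.error('Failed to parse the ONNX file.')
        for i in range(parser.num_errors):
            logger.error(parser.get_error(i))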

        if layer.name in config:
            w_bit = config[layer.name]['weight_bit']
            a_bit = config[layer.name]['activation_bit']
            layer.precision = Precision_Dict[w_bit]
Contributor

Is it possible that w_bit is a value other than 8, 16, or 32? It would be better to add a config validation function to the TRT backend.

Contributor Author

I have added a validation function.
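A possible shape for that validation, assuming the supported precisions are exactly the keys of Precision_Dict (8, 16 and 32 bits); the merged function may differ:

def valid_config(config=None):
    """Check that every layer entry requests a supported bit width."""
    if config is None:
        return
    support_bits = [8, 16, 32]               # keys of Precision_Dict
    for name, layer_config in config.items():
        for key in ('weight_bit', 'activation_bit'):
            if key in layer_config and layer_config[key] not in support_bits:
                raise ValueError("%s of layer %s must be one of %s" % (key, name, support_bits))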

        # entire model in 8bit mode
        builder.int8_mode = True
    else:
        pass
Contributor

this is too hacky

Contributor Author

Sorry, that was a mistake. I have fixed it.

import pycuda.autoinit
import tensorrt as trt

pycuda.autoinit
Contributor

maybe remove?

Contributor Author

This has to be kept: import pycuda.autoinit is necessary here, otherwise pycuda would not be initialized and an error would be raised. But the import itself is not used in the following code, which the Python test pipeline does not allow. So I chose to put this statement here.

Contributor

It would be better to use a comment to suppress pylint for that line.
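For instance, the unused-import warning can be silenced on that single line with a standard pylint pragma (a generic linter directive, not something specific to this PR):

import pycuda.driver as cuda
import pycuda.autoinit  # pylint: disable=unused-import
# pycuda.autoinit must stay imported: it creates the CUDA context that pycuda needs.
import tensorrt as trt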

# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference_v2(context, bindings, inputs, outputs, stream):
    # Transfer input data to the GPU.
    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
Contributor

Is this a code style choice, or is there some other reason for it?

Contributor Author

This is a code style learned from the NVIDIA TensorRT examples.

"""
# Attention that, builder should be set to 1 because of the implementation of allocate_buffer
builder.max_batch_size = 1
builder.max_workspace_size = common.GiB(1)
Contributor

Is fixing the size at 1 GiB enough for all scenarios?

Contributor Author

Good question! I think 1 GiB is enough for a single model. To avoid memory limitations in some special cases, I have extended it to 4 GiB.

@linbinskn linbinskn requested review from QuanluZhang and J-shang April 7, 2021 02:50
    for i in range(network.num_layers):
        if config is None:
            break
        valid_config(config)
Contributor

Does valid_config need to be called this many times?

Contributor Author

I have modified it.
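That is, the validation can be hoisted out of the per-layer loop so it runs once (a sketch of the restructured control flow, using the names from the excerpt above):

if config is not None:
    valid_config(config)                 # validate the whole config once

for i in range(network.num_layers):
    if config is None:
        break
    layer = network.get_layer(i)
    # ... per-layer precision and dynamic-range handling continues here ...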

Input name of onnx model providing for torch.onnx.export to generate onnx model
output_name : list
Output name of onnx model providing for torch.onnx.export to generate onnx model
Returns
Contributor

add a blank line

Contributor Author

I have added it.

in_tensor = layer.get_input(0)
in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)
# Gemm will generate two shuffle layers before and after itself, need specific setting
if layer.name[0:4] == "Gemm":
Contributor

why "Gemm" is handled only when calib is None?

Contributor Author

When calib is not None, the quantization speedup module performs post-training quantization. In the current implementation, we do not apply any extra modification in the post-training quantization path.
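For reference, a simplified sketch of the QAT path (calib is None) that sets per-tensor dynamic ranges from the tracked values and special-cases Gemm, as discussed above. The config key names are assumptions taken from the variable names in the excerpt; the real per-layer bookkeeping in the PR is more involved:

if calib is None and config is not None:
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        if layer.name not in config:
            continue
        tracked_min_input = config[layer.name]['tracked_min_input']
        tracked_max_input = config[layer.name]['tracked_max_input']
        in_tensor = layer.get_input(0)
        in_tensor.dynamic_range = (tracked_min_input, tracked_max_input)
        if layer.name[0:4] == "Gemm":
            # Per the comment in this diff, Gemm generates two shuffle layers before
            # and after itself, and their tensors need specific range settings too.
            pass  # handled with extra range settings in the real implementation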

@linbinskn linbinskn requested a review from QuanluZhang April 8, 2021 08:13
@QuanluZhang QuanluZhang merged commit f0e3c58 into microsoft:master Apr 9, 2021