
Commit

Combine tensorrt tool with NNI quantization algorithms. (#3488)
linbinskn authored Apr 9, 2021
1 parent 80bc953 commit f0e3c58
Showing 11 changed files with 977 additions and 29 deletions.
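In short, the commit wires NNI's simulated quantizers to a real TensorRT backend: a model quantized with QAT_Quantizer exports a calibration config, and the new ModelSpeedupTensorRT class consumes that config to build and run a low-precision engine. Below is a condensed sketch of that flow; the toy Sequential model, file names, and shapes are placeholders, and the full MNIST example added by this commit follows.

import torch
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

# Placeholder model and config; running this requires a CUDA device and TensorRT.
model = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3), torch.nn.ReLU6())
config_list = [{
    'quant_types': ['weight', 'output'],
    'quant_bits': {'weight': 8, 'output': 8},
    'op_names': ['0']
}]
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 1. Simulated quantization-aware training with NNI.
quantizer = QAT_Quantizer(model, config_list, optimizer)
quantizer.compress()
# ... quantization-aware fine-tuning of `model` would go here ...
calibration_config = quantizer.export_model("model.pth", "calibration.pth")

# 2. Hand the exported calibration ranges to TensorRT and build the engine.
engine = ModelSpeedupTensorRT(model, (32, 1, 28, 28), config=calibration_config, batchsize=32)
engine.compress()
output, latency = engine.inference(torch.rand(32, 1, 28, 28))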
169 changes: 169 additions & 0 deletions examples/model_compress/quantization/mixed_precision_speedup_mnist.py
@@ -0,0 +1,169 @@
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
self.relu1 = torch.nn.ReLU6()
self.relu2 = torch.nn.ReLU6()
self.relu3 = torch.nn.ReLU6()
self.max_pool1 = torch.nn.MaxPool2d(2, 2)
self.max_pool2 = torch.nn.MaxPool2d(2, 2)

def forward(self, x):
x = self.relu1(self.conv1(x))
x = self.max_pool1(x)
x = self.relu2(self.conv2(x))
x = self.max_pool2(x)
x = x.view(-1, 4 * 4 * 50)
x = self.relu3(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)


def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)

print('Loss: {} Accuracy: {}%\n'.format(
test_loss, 100 * correct / len(test_loader.dataset)))

def test_trt(engine, test_loader):
test_loss = 0
correct = 0
time_elapsed = 0
for data, target in test_loader:
output, time = engine.inference(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
time_elapsed += time
test_loss /= len(test_loader.dataset)

print('Loss: {} Accuracy: {}%'.format(
test_loss, 100 * correct / len(test_loader.dataset)))
print("Inference elapsed_time (whole dataset): {}s".format(time_elasped))

def post_training_quantization_example(train_loader, test_loader, device):
model = Mnist()

config = {
'conv1':{'weight_bit':8, 'activation_bit':8},
'conv2':{'weight_bit':32, 'activation_bit':32},
'fc1':{'weight_bit':16, 'activation_bit':16},
'fc2':{'weight_bit':8, 'activation_bit':8}
}

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

model.to(device)
for epoch in range(1):
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer)
test(model, device, test_loader)

batch_size = 32
input_shape = (batch_size, 1, 28, 28)

engine = ModelSpeedupTensorRT(model, input_shape, config=config, calib_data_loader=train_loader, batchsize=batch_size)
engine.compress()
test_trt(engine, test_loader)

def quantization_aware_training_example(train_loader, test_loader, device):
model = Mnist()

configure_list = [{
'quant_types': ['weight', 'output'],
'quant_bits': {'weight':8, 'output':8},
'op_names': ['conv1']
}, {
'quant_types': ['output'],
'quant_bits': {'output':8},
'op_names': ['relu1']
}, {
'quant_types': ['weight', 'output'],
'quant_bits': {'weight':8, 'output':8},
'op_names': ['conv2']
}, {
'quant_types': ['output'],
'quant_bits': {'output':8},
'op_names': ['relu2']
}
]

# fine-tune the model using QAT
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = QAT_Quantizer(model, configure_list, optimizer)
quantizer.compress()

model.to(device)
for epoch in range(1):
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer)
test(model, device, test_loader)

model_path = "mnist_model.pth"
calibration_path = "mnist_calibration.pth"
calibration_config = quantizer.export_model(model_path, calibration_path)

test(model, device, test_loader)

print("calibration_config: ", calibration_config)

batch_size = 32
input_shape = (batch_size, 1, 28, 28)

engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size)
engine.compress()

test_trt(engine, test_loader)

def main():
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True, transform=trans),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=trans),
batch_size=1000, shuffle=True)

# post-training quantization on TensorRT
post_training_quantization_example(train_loader, test_loader, device)

# combine the NNI QAT quantization algorithm with the TensorRT backend
quantization_aware_training_example(train_loader, test_loader, device)

if __name__ == '__main__':
main()
42 changes: 28 additions & 14 deletions nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -152,22 +152,23 @@ def __init__(self, model, config_list, optimizer=None):
for layer, config in modules_to_compress:
layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
layer.module.register_buffer("scale", torch.Tensor([1.0]))
layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bit', torch.zeros(1))
layer.module.register_buffer('tracked_min_input', torch.zeros(1))
layer.module.register_buffer('tracked_max_input', torch.zeros(1))
if "output" in config.get("quant_types", []):
layer.module.register_buffer('activation_bit', torch.zeros(1))
layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
layer.module.register_buffer('tracked_min', torch.zeros(1))
layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
layer.module.register_buffer('tracked_max', torch.zeros(1))
layer.module.register_buffer('tracked_min_activation', torch.zeros(1))
layer.module.register_buffer('tracked_max_activation', torch.zeros(1))


def _del_simulated_attr(self, module):
"""
delete redundant parameters in the quantized module
"""
del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_biased', 'tracked_max_biased', 'tracked_min', \
'tracked_max', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', \
'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
@@ -243,15 +244,26 @@ def _dequantize(self, op, quantized_val):
def quantize_weight(self, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
input = kwargs['input_tensor']
weight = copy.deepcopy(wrapper.module.old_weight.data)
weight_bits = get_bits_length(config, 'weight')
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"

# we don't update weight in evaluation stage
if quant_start_step > self.bound_model.steps or not wrapper.training:
if quant_start_step > self.bound_model.steps:
module.tracked_min_input, module.tracked_max_input = torch.min(input), torch.max(input)
return weight

if not wrapper.training:
return weight

current_min, current_max = torch.min(input), torch.max(input)
module.tracked_min_input = update_ema(module.tracked_min_input, current_min,
module.ema_decay)
module.tracked_max_input = update_ema(module.tracked_max_input, current_max,
module.ema_decay)

# if bias exists, quantize bias to uint32
if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
bias = wrapper.module.bias.data
@@ -281,17 +293,17 @@ def quantize_output(self, output, wrapper, **kwargs):
assert output_bits >= 1, "quant bits length should be at least 1"

if quant_start_step > self.bound_model.steps:
module.tracked_min_biased, module.tracked_max_biased = torch.min(output), torch.max(output)
module.tracked_min_activation, module.tracked_max_activation = torch.min(output), torch.max(output)
return output

# we don't update output quantization parameters in evaluation stage
if wrapper.training:
current_min, current_max = torch.min(output), torch.max(output)
module.tracked_min_biased = update_ema(module.tracked_min_biased, current_min,
module.tracked_min_activation = update_ema(module.tracked_min_activation, current_min,
module.ema_decay)
module.tracked_max_biased = update_ema(module.tracked_max_biased, current_max,
module.tracked_max_activation = update_ema(module.tracked_max_activation, current_max,
module.ema_decay)
module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_biased, module.tracked_max_biased)
module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_activation, module.tracked_max_activation)
out = self._quantize(output_bits, module, output)
out = self._dequantize(module, out)
return out
@@ -327,10 +339,12 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
calibration_config[name] = {}
if hasattr(module, 'weight_bit'):
calibration_config[name]['weight_bit'] = int(module.weight_bit)
calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
if hasattr(module, 'activation_bit'):
calibration_config[name]['activation_bit'] = int(module.activation_bit)
calibration_config[name]['tracked_min'] = float(module.tracked_min_biased)
calibration_config[name]['tracked_max'] = float(module.tracked_max_biased)
calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation)
calibration_config[name]['tracked_max_activation'] = float(module.tracked_max_activation)
self._del_simulated_attr(module)

self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)
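Both the weight-input range (tracked_min_input / tracked_max_input) and the activation range (tracked_min_activation / tracked_max_activation) are maintained with the same exponential-moving-average update during training. The update_ema helper itself is not shown in this diff; the sketch below assumes the standard EMA form implied by its call sites above.

import torch

def update_ema(ema_value, current_value, decay):
    # Assumed standard EMA: blend the new statistic into the running estimate
    # with weight (1 - decay), matching the three-argument calls
    # update_ema(tracked_*, current_*, module.ema_decay) above.
    return ema_value * decay + (1 - decay) * current_value

# Example: track a layer's running output range during training.
tracked_min, tracked_max = torch.zeros(1), torch.zeros(1)
decay = torch.tensor([0.99])
for _ in range(100):
    output = torch.randn(32, 10)          # stand-in for a layer's activation
    tracked_min = update_ema(tracked_min, torch.min(output), decay)
    tracked_max = update_ema(tracked_max, torch.max(output), decay)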
10 changes: 4 additions & 6 deletions nni/compression/pytorch/compressor.py
@@ -483,7 +483,7 @@ def forward(self, *inputs):
self.quantizer.quant_grad.apply(
self.module.old_weight,
QuantType.QUANT_WEIGHT,
self)
self, inputs[0])
result = self.module(*inputs)
else:
result = self.module(*inputs)
@@ -511,14 +511,12 @@ def __init__(self, model, config_list, optimizer=None):
# and it is trainable, therefore, it should be added to optimizer.
self.optimizer.add_param_group({"params": wrapper.module.old_weight})

def quantize_weight(self, weight, wrapper, **kwargs):
def quantize_weight(self, wrapper, **kwargs):
"""
Quantizers should overload this method to quantize weight.
This method is effectively hooked to :meth:`forward` of the model.
Parameters
----------
weight : Tensor
weight that needs to be quantized
wrapper : QuantizerModuleWrapper
the wrapper for origin module
"""
@@ -720,11 +718,11 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qma
return grad_output

@staticmethod
def forward(ctx, tensor, quant_type, wrapper, **kwargs):
def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs):
if quant_type == QuantType.QUANT_INPUT:
output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
elif quant_type == QuantType.QUANT_WEIGHT:
output = wrapper.quantizer.quantize_weight(wrapper, **kwargs)
output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
elif quant_type == QuantType.QUANT_OUTPUT:
output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
else:
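With the new signature, quantize_weight no longer receives the weight positionally: it reads the weight from the wrapper and can receive the layer input as the input_tensor keyword, which is how QAT_Quantizer tracks tracked_min_input / tracked_max_input. A minimal sketch of a custom quantizer written against this interface follows; the class name and the range-tracking attributes are illustrative, not part of NNI.

import torch
from nni.compression.pytorch.compressor import Quantizer

class InputRangeQuantizer(Quantizer):
    # Illustrative only: demonstrates the new quantize_weight signature.
    def quantize_weight(self, wrapper, **kwargs):
        weight = wrapper.module.old_weight.data
        input_tensor = kwargs.get('input_tensor')  # forwarded from the wrapper's forward()
        if input_tensor is not None:
            # Record the observed input range, as QAT_Quantizer does.
            wrapper.module.tracked_min_input = torch.min(input_tensor)
            wrapper.module.tracked_max_input = torch.max(input_tensor)
        # A real quantizer would fake-quantize the weight here before returning it.
        return weight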
1 change: 1 addition & 0 deletions nni/compression/pytorch/quantization_speedup/__init__.py
@@ -0,0 +1 @@
from .integrated_tensorrt import CalibrateType, ModelSpeedupTensorRT
51 changes: 51 additions & 0 deletions nni/compression/pytorch/quantization_speedup/backend.py
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

class BaseModelSpeedup:
"""
Base speedup class for backend engine
"""
def __init__(self, model, config):
"""
Parameters
----------
model : pytorch model
The model to speed up by quantization.
config : dict
Config recording the bit widths and names of layers.
"""
self.model = model
self.config = config

def inference(self, test_data):
"""
This function should be overridden by subclasses to provide inference ability;
it should return the output and the inference time.
Parameters
----------
test_data : numpy data
test data given to the inference engine
Returns
-------
numpy data
output data generated by the inference
float
latency of the inference process
"""
raise NotImplementedError('Backend engine must overload inference()')

def compress(self):
"""
This function should be overridden by subclasses to build the inference
engine that will be used to process input data
"""
raise NotImplementedError('Backend engine must overload compress()')

def export_quantized_model(self, path):
"""
This function should be overridden by subclasses to export the
quantized model to the given path
"""
raise NotImplementedError('Backend engine must overload export_quantized_model()')
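ModelSpeedupTensorRT (added in integrated_tensorrt.py and exported via the __init__.py above) is the commit's implementation of this interface; other inference backends would subclass BaseModelSpeedup in the same way. A minimal, purely illustrative subclass that satisfies the contract could look like this:

import time
import torch
from nni.compression.pytorch.quantization_speedup.backend import BaseModelSpeedup

class PassthroughSpeedup(BaseModelSpeedup):
    """Illustrative backend that just runs the original PyTorch model."""

    def compress(self):
        # A real backend would build its inference engine here.
        self.model.eval()

    def inference(self, test_data):
        # The interface requires returning both the output and the latency.
        start = time.time()
        with torch.no_grad():
            output = self.model(test_data)
        return output, time.time() - start

    def export_quantized_model(self, path):
        torch.save(self.model.state_dict(), path)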