This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Combine tensorrt tool with NNI quantization algorithms. #3488

Merged: 21 commits, Apr 9, 2021
Changes from 9 commits
112 changes: 112 additions & 0 deletions examples/model_compress/quantization/mixed_precision_speedup_mnist.py
@@ -0,0 +1,112 @@
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
self.relu1 = torch.nn.ReLU6()
self.relu2 = torch.nn.ReLU6()
self.relu3 = torch.nn.ReLU6()
self.max_pool1 = torch.nn.MaxPool2d(2, 2)
self.max_pool2 = torch.nn.MaxPool2d(2, 2)

def forward(self, x):
x = self.relu1(self.conv1(x))
x = self.max_pool1(x)
x = self.relu2(self.conv2(x))
x = self.max_pool2(x)
x = x.view(-1, 4 * 4 * 50)
x = self.relu3(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)


def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)

    print('Loss: {} Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))

def test_trt(engine, test_loader):
test_loss = 0
correct = 0
    time_elapsed = 0
for data, target in test_loader:
output, time = engine.inference(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
        time_elapsed += time
test_loss /= len(test_loader.dataset)

print('Loss: {} Accuracy: {}%'.format(
test_loss, 100 * correct / len(test_loader.dataset)))
print("Inference elapsed_time (whole dataset): {}s".format(time_elasped))

def main():
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True, transform=trans),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=trans),
batch_size=1000, shuffle=True)

model = Mnist()

config = {
'conv1':{'weight_bit':8, 'activation_bit':8},
'conv2':{'weight_bit':32, 'activation_bit':32},
'fc1':{'weight_bit':16, 'activation_bit':16},
'fc2':{'weight_bit':8, 'activation_bit':8}
}

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

model.to(device)
for epoch in range(1):
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer)
test(model, device, test_loader)

batch_size = 32
input_shape = (batch_size, 1, 28, 28)

engine = ModelSpeedupTensorRT(model, input_shape, config=config, calib_data_loader=train_loader, batchsize=batch_size)
engine.compress()
test_trt(engine, test_loader)

if __name__ == '__main__':
main()
137 changes: 137 additions & 0 deletions examples/model_compress/quantization/mixed_precision_speedup_mnist_QAT.py
@@ -0,0 +1,137 @@
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
self.relu1 = torch.nn.ReLU6()
self.relu2 = torch.nn.ReLU6()
self.relu3 = torch.nn.ReLU6()
self.maxpool1 = torch.nn.MaxPool2d(2, 2)
self.maxpool2 = torch.nn.MaxPool2d(2, 2)

def forward(self, x):
x = self.relu1(self.conv1(x))
x = self.maxpool1(x)
x = self.relu2(self.conv2(x))
x = self.maxpool2(x)
x = x.view(-1, 4 * 4 * 50)
x = self.relu3(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)


def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)

    print('Loss: {} Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))

def test_trt(engine, test_loader):
test_loss = 0
correct = 0
    time_elapsed = 0
for data, target in test_loader:
output, time = engine.inference(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
        time_elapsed += time
test_loss /= len(test_loader.dataset)

print('Loss: {} Accuracy: {}%'.format(
test_loss, 100 * correct / len(test_loader.dataset)))
print("Inference elapsed_time (whole dataset): {}s".format(time_elasped))

def main():
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True, transform=trans),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=trans),
batch_size=1000, shuffle=True)

model = Mnist()

configure_list = [{
'quant_types': ['weight', 'output'],
'quant_bits': {'weight':8, 'output':8},
'op_names': ['conv1']
}, {
'quant_types': ['output'],
'quant_bits': {'output':8},
'op_names': ['relu1']
}, {
'quant_types': ['weight', 'output'],
'quant_bits': {'weight':8, 'output':8},
'op_names': ['conv2']
}, {
'quant_types': ['output'],
'quant_bits': {'output':8},
'op_names': ['relu2']
}
]

# finetune the model by using QAT
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = QAT_Quantizer(model, configure_list, optimizer)
quantizer.compress()

model.to(device)
for epoch in range(1):
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer)
test(model, device, test_loader)

model_path = "mnist_model.pth"
calibration_path = "mnist_calibration.pth"
calibration_config = quantizer.export_model(model_path, calibration_path)

test(model, device, test_loader)

print("calibration_config: ", calibration_config)

batch_size = 32
input_shape = (batch_size, 1, 28, 28)

engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size)
engine.compress()

test_trt(engine, test_loader)
Contributor:
It seems this example includes the above example, so why do we need two examples?

Contributor Author:
In mixed_precision_speedup_mnist.py, the model is quantized by TensorRT directly: we provide a calibration dataset and TensorRT obtains the quantization parameters through its calibration process. We can consider this post-training quantization.
In mixed_precision_speedup_mnist_QAT.py, however, we first finetune the model and obtain the quantization parameters with the QAT algorithm; the model is then quantized by TensorRT without a calibration dataset. We can consider this quantization-aware training.
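For readers comparing the two flows, here is a minimal sketch of how they differ. The helper below is hypothetical, written only for illustration; the ModelSpeedupTensorRT arguments are the ones used in the examples above.

from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

def build_engine(model, input_shape, batch_size, use_qat,
                 qat_config=None, bit_config=None, calib_loader=None):
    # Hypothetical helper contrasting the two quantization flows.
    if use_qat:
        # QAT path: quantization parameters come from
        # QAT_Quantizer.export_model, so no calibration data is needed.
        engine = ModelSpeedupTensorRT(model, input_shape, config=qat_config,
                                      batchsize=batch_size)
    else:
        # Post-training path: TensorRT derives quantization parameters
        # by running the calibration dataset through the network.
        engine = ModelSpeedupTensorRT(model, input_shape, config=bit_config,
                                      calib_data_loader=calib_loader,
                                      batchsize=batch_size)
    engine.compress()
    return engine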

Contributor:
I suggest putting them into one example file.

Contributor Author:
I have put them into one file.


if __name__ == '__main__':
main()
45 changes: 28 additions & 17 deletions nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -31,7 +31,7 @@ def validate_config(self, model, config_list):

schema.validate(config_list)

-    def quantize_weight(self, wrapper, **kwargs):
+    def quantize_weight(self, input, wrapper, **kwargs):
Contributor:
What is the meaning of input? And is it used in this function?

Contributor Author:
In this new version, the input tensor's dynamic range is also recorded, to meet TensorRT's requirement for setting tensor ranges.

Contributor:
So what is the meaning of input?

Contributor Author (@linbinskn, Apr 5, 2021):
input is the input tensor of this op. It is used to calibrate the input tensor's dynamic range. It won't be used by all quantizers, so I have passed it via kwargs.
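For context, a rough sketch of the pattern being described (simplified, not the exact NNI implementation): quantizers that calibrate inputs record the observed range, and the rest can simply ignore the extra argument.

import torch

def quantize_weight(self, input, wrapper, **kwargs):
    module = wrapper.module
    weight = wrapper.module.old_weight.data.clone()
    # Record the input's observed dynamic range so it can later be
    # exported and handed to TensorRT as the tensor's value range.
    module.tracked_min_input = torch.min(input)
    module.tracked_max_input = torch.max(input)
    # ... quantize `weight` as before; quantizers that do not need
    # input calibration never read `input`.
    return weight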

weight = copy.deepcopy(wrapper.module.old_weight.data)
new_scale = weight.abs().max() / 127
scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
@@ -152,22 +152,23 @@ def __init__(self, model, config_list, optimizer=None):
for layer, config in modules_to_compress:
layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
layer.module.register_buffer("scale", torch.Tensor([1.0]))
+            layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bit', torch.zeros(1))
+                layer.module.register_buffer('tracked_min_input', torch.zeros(1))
+                layer.module.register_buffer('tracked_max_input', torch.zeros(1))
if "output" in config.get("quant_types", []):
layer.module.register_buffer('activation_bit', torch.zeros(1))
-                layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
-                layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
-                layer.module.register_buffer('tracked_min', torch.zeros(1))
-                layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
-                layer.module.register_buffer('tracked_max', torch.zeros(1))
+                layer.module.register_buffer('tracked_min_activation', torch.zeros(1))
+                layer.module.register_buffer('tracked_max_activation', torch.zeros(1))


def _del_simulated_attr(self, module):
"""
delete redundant parameters in quantize module
"""
-        del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_biased', 'tracked_max_biased', 'tracked_min', \
-            'tracked_max', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
+        del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', \
+            'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
@@ -240,7 +241,7 @@ def _dequantize(self, op, quantized_val):
real_val = op.scale * (quantized_val - op.zero_point)
return real_val

-    def quantize_weight(self, wrapper, **kwargs):
+    def quantize_weight(self, input, wrapper, **kwargs):
config = wrapper.config
module = wrapper.module
weight = copy.deepcopy(wrapper.module.old_weight.data)
@@ -250,8 +251,16 @@ def quantize_weight(self, wrapper, **kwargs):

# we dont update weight in evaluation stage
if quant_start_step > self.bound_model.steps or not wrapper.training:
+            module.tracked_min_input, module.tracked_max_input = torch.min(input), torch.max(input)
return weight

+        if wrapper.training:
+            current_min, current_max = torch.min(input), torch.max(input)
+            module.tracked_min_input = update_ema(module.tracked_min_input, current_min,
+                                                  module.ema_decay)
+            module.tracked_max_input = update_ema(module.tracked_max_input, current_max,
+                                                  module.ema_decay)

# if bias exists, quantize bias to uint32
if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
bias = wrapper.module.bias.data
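For reference, update_ema as called above maintains an exponential moving average of the observed extrema. A plausible three-argument definition consistent with these call sites (the actual helper is defined elsewhere in NNI) would be:

def update_ema(biased_ema, value, decay):
    # Blend the running estimate toward the newly observed value.
    return biased_ema * decay + (1 - decay) * value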
@@ -281,17 +290,17 @@ def quantize_output(self, output, wrapper, **kwargs):
assert output_bits >= 1, "quant bits length should be at least 1"

if quant_start_step > self.bound_model.steps:
-            module.tracked_min_biased, module.tracked_max_biased = torch.min(output), torch.max(output)
+            module.tracked_min_activation, module.tracked_max_activation = torch.min(output), torch.max(output)
return output

# we dont update output quantization parameters in evaluation stage
if wrapper.training:
current_min, current_max = torch.min(output), torch.max(output)
-            module.tracked_min_biased = update_ema(module.tracked_min_biased, current_min,
+            module.tracked_min_activation = update_ema(module.tracked_min_activation, current_min,
module.ema_decay)
-            module.tracked_max_biased = update_ema(module.tracked_max_biased, current_max,
+            module.tracked_max_activation = update_ema(module.tracked_max_activation, current_max,
module.ema_decay)
-        module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_biased, module.tracked_max_biased)
+        module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_activation, module.tracked_max_activation)
out = self._quantize(output_bits, module, output)
out = self._dequantize(module, out)
return out
@@ -327,10 +336,12 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
calibration_config[name] = {}
if hasattr(module, 'weight_bit'):
calibration_config[name]['weight_bit'] = int(module.weight_bit)
+                calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
+                calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
if hasattr(module, 'activation_bit'):
calibration_config[name]['activation_bit'] = int(module.activation_bit)
-                calibration_config[name]['tracked_min'] = float(module.tracked_min_biased)
-                calibration_config[name]['tracked_max'] = float(module.tracked_max_biased)
+                calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation)
+                calibration_config[name]['tracked_max_activation'] = float(module.tracked_max_activation)
self._del_simulated_attr(module)

self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)
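Given the buffers exported above, a calibration_config entry for the QAT MNIST example would take roughly the following shape (all numeric values here are hypothetical, for illustration only):

calibration_config = {
    'conv1': {
        'weight_bit': 8,
        'tracked_min_input': -0.42,      # hypothetical observed input range
        'tracked_max_input': 2.82,
        'activation_bit': 8,
        'tracked_min_activation': -1.3,  # hypothetical observed output range
        'tracked_max_activation': 6.0,
    },
    # ... one entry per quantized layer
}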
@@ -390,7 +401,7 @@ def validate_config(self, model, config_list):

schema.validate(config_list)

-    def quantize_weight(self, wrapper, **kwargs):
+    def quantize_weight(self, input, wrapper, **kwargs):
weight = copy.deepcopy(wrapper.module.old_weight.data)
weight_bits = get_bits_length(wrapper.config, 'weight')
weight = weight.tanh()
@@ -496,7 +507,7 @@ def validate_config(self, model, config_list):

schema.validate(config_list)

-    def quantize_weight(self, wrapper, **kwargs):
+    def quantize_weight(self, input, wrapper, **kwargs):
weight = copy.deepcopy(wrapper.module.old_weight.data)
weight = torch.sign(weight)
# remove zeros
6 changes: 3 additions & 3 deletions nni/compression/pytorch/compressor.py
@@ -483,7 +483,7 @@ def forward(self, *inputs):
self.quantizer.quant_grad.apply(
self.module.old_weight,
QuantType.QUANT_WEIGHT,
-                    self)
+                    self, inputs[0])
result = self.module(*inputs)
else:
result = self.module(*inputs)
@@ -720,11 +720,11 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qma
return grad_output

@staticmethod
-    def forward(ctx, tensor, quant_type, wrapper, **kwargs):
+    def forward(ctx, tensor, quant_type, wrapper, tensor_alt=None, **kwargs):
Contributor:
What is the meaning of tensor_alt?

Contributor Author:
It is used to transfer a second tensor during the forward process. In the previous implementation we only transferred one tensor (weight, input, or output), but in some situations we need to transfer two tensors so that both can be calibrated, as in quantize_weight. So I added it here.

Contributor:
Is tensor_alt commonly used across different quantizers? If it is specific to some quantizers, I suggest putting it in kwargs; that is what kwargs is for.

Contributor Author:
The argument tensor_alt may not be used by most quantizers, but I don't think it is a bad thing to provide this alternative here. What's more, forward in class QuantGrad is called by apply(), which only supports positional arguments, so kwargs may be empty here. If we forced it into kwargs, an error would be raised. For these reasons, I think it can be kept.
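A minimal illustration of the constraint mentioned above (the Function below is hypothetical): torch.autograd.Function.apply forwards positional arguments only, so an extra tensor has to be passed positionally to reach forward.

import torch

class Passthrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, extra=None, **kwargs):
        # kwargs stays empty in practice: apply() does not accept
        # keyword arguments, so any extra tensor must be positional.
        return tensor if extra is None else tensor + extra

x = torch.ones(2)
y = Passthrough.apply(x, torch.ones(2))       # works: positional
# Passthrough.apply(x, extra=torch.ones(2))   # raises TypeError: apply() takes no keyword arguments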

if quant_type == QuantType.QUANT_INPUT:
output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
elif quant_type == QuantType.QUANT_WEIGHT:
-            output = wrapper.quantizer.quantize_weight(wrapper, **kwargs)
+            output = wrapper.quantizer.quantize_weight(tensor_alt, wrapper, **kwargs)
elif quant_type == QuantType.QUANT_OUTPUT:
output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
else:
1 change: 1 addition & 0 deletions nni/compression/pytorch/quantization_speedup/__init__.py
@@ -0,0 +1 @@
from .integrated_tensorrt import CalibrateType, ModelSpeedupTensorRT