Combine tensorrt tool with NNI quantization algorithms. #3488
Changes from 9 commits
File: mixed_precision_speedup_mnist.py (new file, +112 lines)

```python
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT


class Mnist(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)
        self.relu1 = torch.nn.ReLU6()
        self.relu2 = torch.nn.ReLU6()
        self.relu3 = torch.nn.ReLU6()
        self.max_pool1 = torch.nn.MaxPool2d(2, 2)
        self.max_pool2 = torch.nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.max_pool1(x)
        x = self.relu2(self.conv2(x))
        x = self.max_pool2(x)
        x = x.view(-1, 4 * 4 * 50)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print('Loss: {} Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))


def test_trt(engine, test_loader):
    test_loss = 0
    correct = 0
    time_elapsed = 0
    for data, target in test_loader:
        output, infer_time = engine.inference(data)
        test_loss += F.nll_loss(output, target, reduction='sum').item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        time_elapsed += infer_time
    test_loss /= len(test_loader.dataset)

    print('Loss: {} Accuracy: {}%'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))
    print("Inference elapsed time (whole dataset): {}s".format(time_elapsed))


def main():
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000, shuffle=True)

    model = Mnist()

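    # per-layer mixed-precision setting: bit widths for weights and activations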
    config = {
        'conv1': {'weight_bit': 8, 'activation_bit': 8},
        'conv2': {'weight_bit': 32, 'activation_bit': 32},
        'fc1': {'weight_bit': 16, 'activation_bit': 16},
        'fc2': {'weight_bit': 8, 'activation_bit': 8}
    }

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    model.to(device)
    for epoch in range(1):
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)

    batch_size = 32
    input_shape = (batch_size, 1, 28, 28)

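    # build a TensorRT engine directly from the model; the calibration data
    # loader lets TensorRT derive the quantization ranges itself
    # (post-training quantization)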
    engine = ModelSpeedupTensorRT(model, input_shape, config=config, calib_data_loader=train_loader, batchsize=batch_size)
    engine.compress()
    test_trt(engine, test_loader)


if __name__ == '__main__':
    main()
```
File: mixed_precision_speedup_mnist_QAT.py (new file, +137 lines)

```python
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT


class Mnist(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)
        self.relu1 = torch.nn.ReLU6()
        self.relu2 = torch.nn.ReLU6()
        self.relu3 = torch.nn.ReLU6()
        self.maxpool1 = torch.nn.MaxPool2d(2, 2)
        self.maxpool2 = torch.nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.maxpool1(x)
        x = self.relu2(self.conv2(x))
        x = self.maxpool2(x)
        x = x.view(-1, 4 * 4 * 50)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print('Loss: {} Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))


def test_trt(engine, test_loader):
    test_loss = 0
    correct = 0
    time_elapsed = 0
    for data, target in test_loader:
        output, infer_time = engine.inference(data)
        test_loss += F.nll_loss(output, target, reduction='sum').item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        time_elapsed += infer_time
    test_loss /= len(test_loader.dataset)

    print('Loss: {} Accuracy: {}%'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))
    print("Inference elapsed time (whole dataset): {}s".format(time_elapsed))


def main():
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000, shuffle=True)

    model = Mnist()

    configure_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_names': ['conv1']
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu1']
    }, {
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_names': ['conv2']
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu2']
    }]

    # finetune the model by using QAT
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    quantizer = QAT_Quantizer(model, configure_list, optimizer)
    quantizer.compress()

    model.to(device)
    for epoch in range(1):
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)

    model_path = "mnist_model.pth"
    calibration_path = "mnist_calibration.pth"
    calibration_config = quantizer.export_model(model_path, calibration_path)

    test(model, device, test_loader)

    print("calibration_config: ", calibration_config)

    batch_size = 32
    input_shape = (batch_size, 1, 28, 28)

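    # no calibration data loader here: the quantization ranges come from the
    # QAT-exported calibration_config (quantization-aware training workflow)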
    engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size)
    engine.compress()

    test_trt(engine, test_loader)


if __name__ == '__main__':
    main()
```
File: quantizer implementations

```diff
@@ -31,7 +31,7 @@ def validate_config(self, model, config_list):
 
         schema.validate(config_list)
 
-    def quantize_weight(self, wrapper, **kwargs):
+    def quantize_weight(self, input, wrapper, **kwargs):
```

Review thread on this change:
- Reviewer: what is the meaning of `input`?
- Author: In this new version, the input tensor's dynamic range will also be recorded, to meet the requirement of TensorRT's tensor range setting.
- Reviewer: so what is the meaning of …
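To make that concrete, here is a minimal sketch of how ranges recorded this way could be pushed into an already-built TensorRT network. `set_dynamic_range` is TensorRT's actual Python API; the helper itself and the config keys (matching the `calibration_config` exported later in this diff) are illustrative, not NNI's implementation.

```python
def apply_tracked_ranges(network, calibration_config):
    """Hypothetical helper: feed NNI-tracked ranges into a TensorRT network."""
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        cfg = calibration_config.get(layer.name)
        if cfg is None:
            continue
        # int8 engines need an explicit (min, max) dynamic range per tensor
        if 'tracked_min_input' in cfg:
            layer.get_input(0).set_dynamic_range(
                cfg['tracked_min_input'], cfg['tracked_max_input'])
        if 'tracked_min_activation' in cfg:
            layer.get_output(0).set_dynamic_range(
                cfg['tracked_min_activation'], cfg['tracked_max_activation'])
```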
```diff
         weight = copy.deepcopy(wrapper.module.old_weight.data)
         new_scale = weight.abs().max() / 127
         scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
 
@@ -152,22 +152,23 @@ def __init__(self, model, config_list, optimizer=None):
         for layer, config in modules_to_compress:
             layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
             layer.module.register_buffer("scale", torch.Tensor([1.0]))
-            layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
             if "weight" in config.get("quant_types", []):
                 layer.module.register_buffer('weight_bit', torch.zeros(1))
+                layer.module.register_buffer('tracked_min_input', torch.zeros(1))
+                layer.module.register_buffer('tracked_max_input', torch.zeros(1))
             if "output" in config.get("quant_types", []):
                 layer.module.register_buffer('activation_bit', torch.zeros(1))
-                layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
-                layer.module.register_buffer('tracked_min', torch.zeros(1))
-                layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
-                layer.module.register_buffer('tracked_max', torch.zeros(1))
+                layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
+                layer.module.register_buffer('tracked_min_activation', torch.zeros(1))
+                layer.module.register_buffer('tracked_max_activation', torch.zeros(1))
 
     def _del_simulated_attr(self, module):
         """
         delete redundant parameters in quantize module
         """
-        del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_biased', 'tracked_max_biased', 'tracked_min', \
-            'tracked_max', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
+        del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', \
+            'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
         for attr in del_attr_list:
             if hasattr(module, attr):
                 delattr(module, attr)
 
@@ -240,7 +241,7 @@ def _dequantize(self, op, quantized_val):
         real_val = op.scale * (quantized_val - op.zero_point)
         return real_val
 
-    def quantize_weight(self, wrapper, **kwargs):
+    def quantize_weight(self, input, wrapper, **kwargs):
         config = wrapper.config
         module = wrapper.module
         weight = copy.deepcopy(wrapper.module.old_weight.data)
 
@@ -250,8 +251,16 @@ def quantize_weight(self, wrapper, **kwargs):
 
         # we don't update weight in evaluation stage
         if quant_start_step > self.bound_model.steps or not wrapper.training:
+            module.tracked_min_input, module.tracked_max_input = torch.min(input), torch.max(input)
             return weight
 
+        if wrapper.training:
+            current_min, current_max = torch.min(input), torch.max(input)
+            module.tracked_min_input = update_ema(module.tracked_min_input, current_min,
+                                                  module.ema_decay)
+            module.tracked_max_input = update_ema(module.tracked_max_input, current_max,
+                                                  module.ema_decay)
```
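For reference, `update_ema` maintains an exponential moving average of the observed extrema. Its exact implementation lives in NNI's quantization utilities; a common form (an assumption here, not quoted from the PR) is:

```python
def update_ema(ema, current, decay):
    # exponential moving average: keep `decay` of history, blend in the new value
    return ema * decay + (1 - decay) * current
```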
```diff
 
         # if bias exists, quantize bias to uint32
         if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
             bias = wrapper.module.bias.data
 
@@ -281,17 +290,17 @@ def quantize_output(self, output, wrapper, **kwargs):
         assert output_bits >= 1, "quant bits length should be at least 1"
 
         if quant_start_step > self.bound_model.steps:
-            module.tracked_min_biased, module.tracked_max_biased = torch.min(output), torch.max(output)
+            module.tracked_min_activation, module.tracked_max_activation = torch.min(output), torch.max(output)
             return output
 
         # we don't update output quantization parameters in evaluation stage
         if wrapper.training:
             current_min, current_max = torch.min(output), torch.max(output)
-            module.tracked_min_biased = update_ema(module.tracked_min_biased, current_min,
-                                                   module.ema_decay)
+            module.tracked_min_activation = update_ema(module.tracked_min_activation, current_min,
+                                                       module.ema_decay)
-            module.tracked_max_biased = update_ema(module.tracked_max_biased, current_max,
-                                                   module.ema_decay)
+            module.tracked_max_activation = update_ema(module.tracked_max_activation, current_max,
+                                                       module.ema_decay)
-            module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_biased, module.tracked_max_biased)
+            module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_activation, module.tracked_max_activation)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
         return out
 
@@ -327,10 +336,12 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
             calibration_config[name] = {}
             if hasattr(module, 'weight_bit'):
                 calibration_config[name]['weight_bit'] = int(module.weight_bit)
+                calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
+                calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
             if hasattr(module, 'activation_bit'):
                 calibration_config[name]['activation_bit'] = int(module.activation_bit)
-                calibration_config[name]['tracked_min'] = float(module.tracked_min_biased)
-                calibration_config[name]['tracked_max'] = float(module.tracked_max_biased)
+                calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation)
+                calibration_config[name]['tracked_max_activation'] = float(module.tracked_max_activation)
             self._del_simulated_attr(module)
 
         self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)
@@ -390,7 +401,7 @@ def validate_config(self, model, config_list): | |
|
||
schema.validate(config_list) | ||
|
||
def quantize_weight(self, wrapper, **kwargs): | ||
def quantize_weight(self, input, wrapper, **kwargs): | ||
weight = copy.deepcopy(wrapper.module.old_weight.data) | ||
weight_bits = get_bits_length(wrapper.config, 'weight') | ||
weight = weight.tanh() | ||
|
@@ -496,7 +507,7 @@ def validate_config(self, model, config_list): | |
|
||
schema.validate(config_list) | ||
|
||
def quantize_weight(self, wrapper, **kwargs): | ||
def quantize_weight(self, input, wrapper, **kwargs): | ||
weight = copy.deepcopy(wrapper.module.old_weight.data) | ||
weight = torch.sign(weight) | ||
# remove zeros | ||
|
File: compressor (quantizer module wrapper and quant_grad)

```diff
@@ -483,7 +483,7 @@ def forward(self, *inputs):
             self.quantizer.quant_grad.apply(
                 self.module.old_weight,
                 QuantType.QUANT_WEIGHT,
-                self)
+                self, inputs[0])
             result = self.module(*inputs)
         else:
             result = self.module(*inputs)
 
@@ -720,11 +720,11 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
         return grad_output
 
     @staticmethod
-    def forward(ctx, tensor, quant_type, wrapper, **kwargs):
+    def forward(ctx, tensor, quant_type, wrapper, tensor_alt=None, **kwargs):
```

Review thread on this change:
- Reviewer: what is the meaning of `tensor_alt`?
- Author: It is used to transfer the second tensor during the forward process. In the previous implementation we only transferred one tensor (weight, input, or output), but in some situations we need to transfer two tensors so that both of them can be calibrated, as in `quantize_weight`.
- Reviewer: is …
- Author: The argument …

```diff
         if quant_type == QuantType.QUANT_INPUT:
             output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
         elif quant_type == QuantType.QUANT_WEIGHT:
-            output = wrapper.quantizer.quantize_weight(wrapper, **kwargs)
+            output = wrapper.quantizer.quantize_weight(tensor_alt, wrapper, **kwargs)
         elif quant_type == QuantType.QUANT_OUTPUT:
             output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
         else:
```
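A self-contained sketch of the pattern this thread describes: a custom `torch.autograd.Function` that fake-quantizes one tensor while a companion tensor rides along through `forward` purely for range tracking. The class and the `tracker` dict are illustrative, not NNI's API.

```python
import torch

class FakeQuantWithCompanion(torch.autograd.Function):
    """Illustrative: quantize `tensor`; `tensor_alt` is only observed."""

    @staticmethod
    def forward(ctx, tensor, tracker, tensor_alt=None):
        if tensor_alt is not None:
            # record the companion tensor's observed range (e.g. the layer input)
            tracker['min'] = min(tracker.get('min', float('inf')), tensor_alt.min().item())
            tracker['max'] = max(tracker.get('max', float('-inf')), tensor_alt.max().item())
        # symmetric 8-bit fake quantization of the main tensor
        scale = tensor.abs().max() / 127
        return torch.round(tensor / scale).clamp(-128, 127) * scale

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: the gradient passes through the rounding;
        # the tracker and companion tensor receive no gradient
        return grad_output, None, None

tracker = {}
w = torch.randn(4, 4, requires_grad=True)
x = torch.randn(8, 4)
w_q = FakeQuantWithCompanion.apply(w, tracker, x)
w_q.sum().backward()  # gradients reach w via the straight-through estimator
```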
File: nni/compression/pytorch/quantization_speedup/__init__.py (new file, +1 line)

```python
from .integrated_tensorrt import CalibrateType, ModelSpeedupTensorRT
```
Review thread on the examples:
- Reviewer: seems this example includes the above example, so why need two examples?
- Author: In mixed_precision_speedup_mnist.py, the model is quantized in TensorRT directly by providing a calibration dataset, and TensorRT obtains the quantization parameters through its calibration process. We can consider it post-training quantization. In mixed_precision_speedup_mnist_QAT.py, however, we first finetune the model and obtain the quantization parameters with the QAT algorithm; that model is then quantized in TensorRT without a calibration dataset. We can consider it quantization-aware training.
- Reviewer: suggest to put them into one example file
- Author: Have put them into one file.
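Putting that resolution in code, here is a condensed sketch of the two workflows side by side, drawn directly from the examples above (model, loaders, config objects, and helper functions as defined there):

```python
# Workflow 1: post-training quantization. TensorRT receives a calibration
# data loader and derives the quantization ranges itself.
engine = ModelSpeedupTensorRT(model, input_shape, config=config,
                              calib_data_loader=train_loader, batchsize=batch_size)
engine.compress()

# Workflow 2: quantization-aware training. QAT_Quantizer learns the ranges
# during finetuning; TensorRT then consumes the exported calibration_config
# and needs no calibration data.
quantizer = QAT_Quantizer(model, configure_list, optimizer)
quantizer.compress()
train(model, device, train_loader, optimizer)  # finetune with fake quantization
calibration_config = quantizer.export_model("mnist_model.pth", "mnist_calibration.pth")
engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config,
                              batchsize=batch_size)
engine.compress()
```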