Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add LSQ quantizer #3503

Merged
merged 18 commits into from
May 18, 2021
145 changes: 145 additions & 0 deletions examples/model_compress/quantization/LSQ_torch_quantizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import LsqQuantizer
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT


class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
self.relu1 = torch.nn.ReLU6()
self.relu2 = torch.nn.ReLU6()
self.relu3 = torch.nn.ReLU6()
self.max_pool1 = torch.nn.MaxPool2d(2, 2)
self.max_pool2 = torch.nn.MaxPool2d(2, 2)

def forward(self, x):
x = self.relu1(self.conv1(x))
x = self.max_pool1(x)
x = self.relu2(self.conv2(x))
x = self.max_pool2(x)
x = x.view(-1, 4 * 4 * 50)
x = self.relu3(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)


def train(model, quantizer, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)

print('Loss: {} Accuracy: {}%)\n'.format(
test_loss, 100 * correct / len(test_loader.dataset)))


def test_trt(engine, test_loader):
test_loss = 0
correct = 0
time_elasped = 0
for data, target in test_loader:
output, time = engine.inference(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
time_elasped += time
test_loss /= len(test_loader.dataset)

print('Loss: {} Accuracy: {}%'.format(
test_loss, 100 * correct / len(test_loader.dataset)))
print("Inference elapsed_time (whole dataset): {}s".format(time_elasped))


def main():
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True, transform=trans),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=trans),
batch_size=1000, shuffle=True)

model = Mnist()
'''you can change this to DoReFaQuantizer to implement it
DoReFaQuantizer(configure_list).compress(model)
'''
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment can be removed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

configure_list = [{
'quant_types': ['weight', 'input'],
'quant_bits': {'weight': 8, 'input': 8},
'op_names': ['conv1']
}, {
'quant_types': ['output'],
'quant_bits': {'output': 8, },
'op_names': ['relu1']
}, {
'quant_types': ['weight', 'input'],
'quant_bits': {'weight': 8, 'input': 8},
'op_names': ['conv2']
}, {
'quant_types': ['output'],
'quant_bits': {'output': 8},
'op_names': ['relu2']
}, {
'quant_types': ['output'],
'quant_bits': {'output': 8},
'op_names': ['max_pool2']
}
]
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = LsqQuantizer(model, configure_list, optimizer)
quantizer.compress()

model.to(device)
for epoch in range(40):
print('# Epoch {} #'.format(epoch))
train(model, quantizer, device, train_loader, optimizer)
test(model, device, test_loader)

model_path = "mnist_model.pth"
calibration_path = "mnist_calibration.pth"
calibration_config = quantizer.export_model(model_path, calibration_path)

test(model, device, test_loader)

print("calibration_config: ", calibration_config)

batch_size = 32
input_shape = (batch_size, 1, 28, 28)

engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size)
engine.compress()

test_trt(engine, test_loader)


if __name__ == '__main__':
main()
208 changes: 202 additions & 6 deletions nni/algorithms/compression/pytorch/quantization/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.compressor import Quantizer, QuantGrad, QuantType
from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType

__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']
__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,7 +59,7 @@ def update_ema(biased_ema, value, decay):
float, float
"""
biased_ema = biased_ema * decay + (1 - decay) * value
return biased_ema
return biased_ema


def update_quantization_param(bits, rmin, rmax):
Expand Down Expand Up @@ -146,7 +146,7 @@ def __init__(self, model, config_list, optimizer=None):
types of nn.module you want to apply quantization, eg. 'Conv2d'
"""
super().__init__(model, config_list, optimizer)
self.quant_grad = QATGrad
self.quant_grad = QATGrad.apply
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we have to move apply here instead of using it directly?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it is for avoiding STE in LSQ quantizer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is aimed at unifying the framework of quantizers with customized gradient and quantizers with auto-grad gradient. Also, use.apply is the way recommended by PyTorch (see here)

modules_to_compress = self.get_modules_to_compress()
self.bound_model.register_buffer("steps", torch.Tensor([1]))
for layer, config in modules_to_compress:
Expand Down Expand Up @@ -474,7 +474,7 @@ class BNNQuantizer(Quantizer):

def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.quant_grad = ClipGrad
self.quant_grad = ClipGrad.apply
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
if "weight" in config.get("quant_types", []):
Expand Down Expand Up @@ -559,4 +559,200 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_

self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)

return calibration_config
return calibration_config


class LsqQuantizer(Quantizer):
Copy link
Contributor

@linbinskn linbinskn Apr 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add docstring as the other Quantizers, especially for parameters and return.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""Quantizer defined in:
Learned Step Size Quantization (ICLR 2020)
https://arxiv.org/pdf/1902.08153.pdf
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please align

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def __init__(self, model, config_list, optimizer=None):
"""
Parameters
----------
model : torch.nn.Module
the model to be quantized
config_list : list of dict
list of configurations for quantization
supported keys for dict:
- quant_types : list of string
type of quantization you want to apply, currently support 'weight', 'input', 'output'
- quant_bits : int or dict of {str : int}
bits length of quantization, key is the quantization type, value is the length, eg. {'weight', 8},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

{'weight', 8} -> {'weight': 8}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

when the type is int, all quantization types share same bits length
- quant_start_step : int
disable quantization until model are run by certain number of steps, this allows the network to enter a more stable
state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
- op_types : list of string
types of nn.module you want to apply quantization, eg. 'Conv2d'
"""
super().__init__(model, config_list, optimizer)
self.quant_grad = QuantForward()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we keep the original forward and backward structure, the Lsq can forward as usual and backward by STE. In this way, will it be anything wrong? May be have something to do with the update of scale and zeropoint.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There will not be anything wrong if the gradients are handled carefully. However, there exists one major limitation for the origin framework, that is, we must customize all gradients for all learnable parameters. If the gradient-based algorithms become complex, it will be troubling and error-prone to do the customization. In this situation, I think using the auto-grad system to determine the gradient is more convenient for users.

modules_to_compress = self.get_modules_to_compress()
self.bound_model.register_buffer("steps", torch.Tensor([1]))
for layer, config in modules_to_compress:
if "weight" in config.get("quant_types", []):
layer.module.register_parameter("weight_scale", torch.nn.Parameter(torch.Tensor([1.0])))
# todo: support per-channel quantization for weight since TensorRT use it for conv weight
q_bit = get_bits_length(config, "weight")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In current implementation, we only support single bit quantization in LsqQuantizer? Can we support mixed precision right now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that mixed precision of quantization is supported in this implementation since each layer has its own q_bit. We can achieve mixed quantization through some specific settings in config_list like:

configure_list = [{
        'quant_types': ['weight'],
        'quant_bits': 8,
        'op_types': ['Conv2d'],
        'op_names': ['features.3']
    }, {
        'quant_types': ['weight'],
        'quant_bits': 7,
        'op_types': ['Conv2d'],
        'op_names': ['features.6']
    }]

layer.module.register_buffer('weight_bit', torch.Tensor([q_bit]))
qmax = 2 ** (q_bit - 1) - 1
qmin = -2 ** (q_bit - 1)
init_weight_scale = layer.module.weight.data.detach().abs().mean() * 2 / (qmax ** 0.5)
layer.module.weight_scale = torch.nn.Parameter(init_weight_scale)
layer.module.weight_qmax = qmax
layer.module.weight_qmin = qmin

self.optimizer.add_param_group({"params": layer.module.weight_scale})

if "output" in config.get("quant_types", []):
# scale of activation will be initialized using the first batch data
layer.module.register_parameter("output_scale", torch.nn.Parameter(torch.Tensor([1.0])))
q_bit = get_bits_length(config, "output")
layer.module.register_buffer('output_bit', torch.Tensor([q_bit]))
qmax = 2 ** (q_bit - 1) - 1
qmin = -2 ** (q_bit - 1)
layer.module.output_qmax = qmax
layer.module.output_qmin = qmin

self.optimizer.add_param_group({"params": layer.module.output_scale})

if "input" in config.get("quant_types", []):
# scale of activation will be initialized using the first batch data
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

activation -> input

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

layer.module.register_parameter("input_scale", torch.nn.Parameter(torch.Tensor([1.0])))
q_bit = get_bits_length(config, "input")
layer.module.register_buffer('input_bit', torch.Tensor([q_bit]))
qmax = 2 ** (q_bit - 1) - 1
qmin = -2 ** (q_bit - 1)
layer.module.input_qmax = qmax
layer.module.input_qmin = qmin

self.optimizer.add_param_group({"params": layer.module.input_scale})

@staticmethod
def grad_scale(x, scale):
"""
Used to scale the gradient
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommend explaining this function in detail since both of reviewers were confused during reviewing this part. Whatever, I think this function is also part of key implementation of LSQ which can helps others understand the insight of this algorithm.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""
y = x
y_grad = x * scale
return (y - y_grad).detach() + y_grad

@staticmethod
def round_pass(x):
"""
A simple way to execute `round` operation with grad set to 1
"""
y = x.round()
y_grad = x
return (y - y_grad).detach() + y_grad

def quantize(self, x, scale, qmin, qmax):
grad_scale_factor = 1.0 / ((qmax * x.numel()) ** 0.5)
scale = self.grad_scale(scale, grad_scale_factor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A little confused about the name of value and function. Can we polish naming here or in grad_scale function? For instance, change the second parameter name 'scale' to 'scale_factor'.

Copy link
Contributor Author

@chenbohua3 chenbohua3 Apr 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The names of functions and variables are the same as those defined in the paper.

x = x / scale
x = torch.clamp(x, qmin, qmax)
x = self.round_pass(x)
x = x * scale
return x

def quantize_weight(self, wrapper, **kwargs):
module = wrapper.module

# todo: add support for quantize bias. If we use TensorRT as backend, there is no need to quantize
# bias
old_weight = module.old_weight
weight = self.quantize(old_weight, module.weight_scale, module.weight_qmin, module.weight_qmax)
module.weight = weight
return weight

def quantize_output(self, output, wrapper, **kwargs):
module = wrapper.module

# initialize the scale
if self.bound_model.steps == 1:
qmax = module.output_qmax
init_oup_scale = output.data.detach().abs().mean() * 2 / (qmax ** 0.5)
module.output_scale.data = init_oup_scale

output = self.quantize(output, module.output_scale, module.output_qmin, module.output_qmax)
return output
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this quantization algorithm support exporting model and related quantization parameters? If yes, maybe we can consider adding function export_model() based on what parameters should export to inference framework like TensorRT.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will check it out


def quantize_input(self, *inputs, wrapper, **kwargs):
# This is hacky since it is not recommended to modify a tuple
# NB: support layers with multi inputs
module = wrapper.module
# initialize the scale
if self.bound_model.steps == 1:
qmax = module.input_qmax
init_oup_scale = inputs[0].data.detach().abs().mean() * 2 / (qmax ** 0.5)
module.input_scale.data = init_oup_scale

new_input = self.quantize(inputs[0], module.input_scale, module.input_qmin, module.input_qmax)
list_inp = list(inputs)
list_inp[0] = new_input
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we only quantize the first input

Copy link
Contributor Author

@chenbohua3 chenbohua3 May 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that currently the quantization framework only supports layers with single input (see here, so is the trt backend, see here ). So current implementation does not support layers with multi inputs. It may be a better choice to modify the lsq quantizer to support layers with multi inputs after the framework supports it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, it is reasonable

return tuple(list_inp)

def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and calibration parameters(optional)

Parameters
----------
model_path : str
path to save quantized model weight
calibration_path : str
(optional) path to save quantize parameters after calibration
onnx_path : str
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
device : torch.device
device of the model, used to place the dummy input tensor for exporting onnx file.
the tensor is placed on cpu if ```device``` is None

Returns
-------
Dict
"""
assert model_path is not None, 'model_path must be specified'
self._unwrap_model()
calibration_config = {}

for name, module in self.bound_model.named_modules():
if hasattr(module, 'input_bit') or hasattr(module, 'output_bit'):
calibration_config[name] = {}
if hasattr(module, 'weight_bit'):
calibration_config[name]['weight_bit'] = int(module.weight_bit)
abs_max_input = float(module.input_scale * module.input_qmax)
calibration_config[name]['tracked_min_input'] = -abs_max_input
calibration_config[name]['tracked_max_input'] = abs_max_input
if hasattr(module, 'output_bit'):
calibration_config[name]['activation_bit'] = int(module.output_bit)
abs_max_output = float(module.output_scale * module.output_qmax)
calibration_config[name]['tracked_min_activation'] = -abs_max_output
calibration_config[name]['tracked_max_activation'] = abs_max_output
self._del_simulated_attr(module)

self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
input_shape, device)

return calibration_config

def _del_simulated_attr(self, module):
"""
delete redundant parameters in quantize module
"""
del_attr_list = ['old_weight', 'tracked_min_input', 'tracked_max_input', 'tracked_min_activation', \
'tracked_max_activation', 'output_scale', 'input_scale', 'weight_scale','weight_bit', 'output_bit', 'input_bit']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)

def step_with_optimizer(self):
"""
override `compressor` `step` method, quantization only happens after certain number of steps
"""
self.bound_model.steps += 1
Loading