
Commit af929fd

Add LSQ quantizer (#3503)
chenbohua3 authored May 18, 2021
1 parent 761732a commit af929fd
Showing 5 changed files with 435 additions and 19 deletions.
2 changes: 2 additions & 0 deletions docs/en_US/Compression/Overview.rst
@@ -87,6 +87,8 @@ Quantization algorithms compress the original network by reducing the number of
- DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. `Reference Paper <https://arxiv.org/abs/1606.06160>`__
* - `BNN Quantizer <../Compression/Quantizer.rst#bnn-quantizer>`__
- Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1. `Reference Paper <https://arxiv.org/abs/1602.02830>`__
* - `LSQ Quantizer <../Compression/Quantizer.rst#lsq-quantizer>`__
- Learned step size quantization. `Reference Paper <https://arxiv.org/pdf/1902.08153.pdf>`__


Model Speedup
56 changes: 56 additions & 0 deletions docs/en_US/Compression/Quantizer.rst
@@ -8,6 +8,7 @@ Index of supported quantization algorithms
* `QAT Quantizer <#qat-quantizer>`__
* `DoReFa Quantizer <#dorefa-quantizer>`__
* `BNN Quantizer <#bnn-quantizer>`__
* `LSQ Quantizer <#lsq-quantizer>`__

Naive Quantizer
---------------
@@ -86,6 +87,61 @@ note

Batch normalization folding is currently not supported.

----

LSQ Quantizer
-------------

In `LEARNED STEP SIZE QUANTIZATION <https://arxiv.org/pdf/1902.08153.pdf>`__\ , Steven K. Esser, Jeffrey L. McKinstry, et al. present an algorithm that learns the quantization step sizes (scales) jointly with the network weights through gradient descent.

..
The authors introduce a novel means to estimate and scale the task loss gradient at each weight and activation layer’s quantizer step size, such that it can be learned in conjunction with other network parameters.

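The key mechanism is a straight-through estimator for the rounding step plus a scale factor on the step-size gradient. Below is a minimal sketch of the idea, assuming per-tensor quantization; the helper names (``grad_scale``, ``round_pass``, ``lsq_quantize``) are illustrative and are not part of NNI's API.

.. code-block:: python

    import torch

    def grad_scale(x, scale):
        # Forward pass returns x unchanged; backward pass multiplies the
        # incoming gradient by ``scale`` (the paper's gradient-scale trick).
        return (x - x * scale).detach() + x * scale

    def round_pass(x):
        # Straight-through estimator: round in the forward pass,
        # identity gradient in the backward pass.
        return (x.round() - x).detach() + x

    def lsq_quantize(v, step_size, q_min, q_max, g):
        # ``g`` is typically 1 / sqrt(num_elements * q_max), keeping the
        # step-size gradient comparable in magnitude to weight gradients.
        s = grad_scale(step_size, g)
        v_bar = torch.clamp(v / s, q_min, q_max)
        # Quantize, then dequantize, so training sees the quantization error.
        return round_pass(v_bar) * s

Because the step sizes are "learned in conjunction with other network parameters" (see the quote above), the quantizer takes your ``optimizer`` in the usage below so they can be updated alongside the weights.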

Usage
^^^^^
You can add the code below before your training code. Three things must be done:


1. Configure which layers to quantize and which tensors (input/output/weight) of those layers to quantize.
2. Construct the LSQ quantizer.
3. Call the `compress` API.


PyTorch code

.. code-block:: python

    from nni.algorithms.compression.pytorch.quantization import LsqQuantizer

    model = Mnist()
    configure_list = [{
        'quant_types': ['weight', 'input'],
        'quant_bits': {
            'weight': 8,
            'input': 8,
        },
        'op_names': ['conv1']
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu1']
    }]
    quantizer = LsqQuantizer(model, configure_list, optimizer)
    quantizer.compress()

See :githublink:`examples/model_compress/quantization/LSQ_torch_quantizer.py <examples/model_compress/quantization/LSQ_torch_quantizer.py>` for a complete example.

User configuration for LSQ Quantizer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Common configuration needed by compression algorithms can be found in the `specification of config_list <./QuickStart.rst>`__.

Configuration specific to this algorithm: none beyond the common configuration.


----

DoReFa Quantizer
142 changes: 142 additions & 0 deletions examples/model_compress/quantization/LSQ_torch_quantizer.py
@@ -0,0 +1,142 @@
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import LsqQuantizer
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT


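# A small LeNet-style CNN for MNIST. Note that ReLU6 bounds the activations,
# which generally suits activation quantization better than unbounded ReLU.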
class Mnist(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)
        self.relu1 = torch.nn.ReLU6()
        self.relu2 = torch.nn.ReLU6()
        self.relu3 = torch.nn.ReLU6()
        self.max_pool1 = torch.nn.MaxPool2d(2, 2)
        self.max_pool2 = torch.nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.max_pool1(x)
        x = self.relu2(self.conv2(x))
        x = self.max_pool2(x)
        x = x.view(-1, 4 * 4 * 50)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(model, quantizer, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}%  Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print('Loss: {}  Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))


def test_trt(engine, test_loader):
    test_loss = 0
    correct = 0
    time_elapsed = 0
    for data, target in test_loader:
        output, time = engine.inference(data)
        test_loss += F.nll_loss(output, target, reduction='sum').item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        time_elapsed += time
    test_loss /= len(test_loader.dataset)

    print('Loss: {}  Accuracy: {}%'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))
    print("Inference elapsed_time (whole dataset): {}s".format(time_elapsed))


def main():
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000, shuffle=True)

    model = Mnist()
    configure_list = [{
        'quant_types': ['weight', 'input'],
        'quant_bits': {'weight': 8, 'input': 8},
        'op_names': ['conv1']
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu1']
    }, {
        'quant_types': ['weight', 'input'],
        'quant_bits': {'weight': 8, 'input': 8},
        'op_names': ['conv2']
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu2']
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['max_pool2']
    }]
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
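    # Steps 2 and 3 from the docs: construct the LSQ quantizer and call
    # compress(). The configured modules are wrapped with fake-quantization
    # ops whose step sizes are learned during training.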
    quantizer = LsqQuantizer(model, configure_list, optimizer)
    quantizer.compress()

    model.to(device)
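    # Quantization-aware training: the weights and the learnable step sizes
    # are updated jointly by the same optimizer.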
    for epoch in range(40):
        print('# Epoch {} #'.format(epoch))
        train(model, quantizer, device, train_loader, optimizer)
        test(model, device, test_loader)

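    # Export the quantized model along with the calibration configuration
    # that the TensorRT speedup below consumes.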
    model_path = "mnist_model.pth"
    calibration_path = "mnist_calibration.pth"
    calibration_config = quantizer.export_model(model_path, calibration_path)

    test(model, device, test_loader)

    print("calibration_config: ", calibration_config)

    batch_size = 32
    input_shape = (batch_size, 1, 28, 28)

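    # Build a TensorRT engine from the calibration configuration, moving from
    # simulated (fake) quantization to real low-bit inference.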
    engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size)
    engine.compress()

    test_trt(engine, test_loader)


if __name__ == '__main__':
    main()