From 3700105c7fb4b8f6d963b53fd6598790821a5e66 Mon Sep 17 00:00:00 2001 From: v-linbin Date: Wed, 24 Mar 2021 20:27:00 +0800 Subject: [PATCH] Add doc for quantizer export_model() --- docs/en_US/Compression/QuickStart.rst | 5 ++--- .../model_compress/quantization/QAT_torch_quantizer.py | 8 ++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/docs/en_US/Compression/QuickStart.rst b/docs/en_US/Compression/QuickStart.rst index 9aa7181216..1113e5a753 100644 --- a/docs/en_US/Compression/QuickStart.rst +++ b/docs/en_US/Compression/QuickStart.rst @@ -110,12 +110,11 @@ Step2. Choose a quantizer and compress the model Step3. Export compression result ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -You can export the quantized model directly by using ``torch.save`` api and the quantized model can be loaded by ``torch.load`` without any extra modification. +After training and calibration, you can export the model weights to a file, and the generated calibration parameters to a file as well. Exporting an ONNX model is also supported. .. code-block:: python - # Save quantized model which is generated by using NNI QAT algorithm - torch.save(model.state_dict(), "quantized_model.pth") + calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device) Plese refer to :githublink:`mnist example ` for example code. 
diff --git a/examples/model_compress/quantization/QAT_torch_quantizer.py b/examples/model_compress/quantization/QAT_torch_quantizer.py index 8dc58aa8ce..ef14ff5ce0 100644 --- a/examples/model_compress/quantization/QAT_torch_quantizer.py +++ b/examples/model_compress/quantization/QAT_torch_quantizer.py @@ -92,6 +92,14 @@ def main(): train(model, quantizer, device, train_loader, optimizer) test(model, device, test_loader) + model_path = "mnist_model.pth" + calibration_path = "mnist_calibration.pth" + onnx_path = "mnist_model.onnx" + input_shape = (1, 1, 28, 28) + device = torch.device("cuda") + + calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device) + print("Generated calibration config is: ", calibration_config) if __name__ == '__main__': main()