Added nncf compression/quantization (#2052)
* Added nncf compression/quantization

Signed-off-by: Adrian Boguszewski <adrian.boguszewski@intel.com>

* Docs changes

Signed-off-by: Adrian Boguszewski <adrian.boguszewski@intel.com>

* Minor changes

Signed-off-by: Adrian Boguszewski <adrian.boguszewski@intel.com>

* Changes according to review

Signed-off-by: Adrian Boguszewski <adrian.boguszewski@intel.com>

* Fixes in CLI

Signed-off-by: Adrian Boguszewski <adrian.boguszewski@intel.com>

---------

Signed-off-by: Adrian Boguszewski <adrian.boguszewski@intel.com>
adrianboguszewski authored May 17, 2024
1 parent 28e023e commit 4d0beb0
Showing 10 changed files with 82 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

- 🚀 Update OpenVINO and ONNX export to support fixed input shape by @adrianboguszewski in https://github.com/openvinotoolkit/anomalib/pull/2006
- Add data_path argument to predict entrypoint and add properties for retrieving model path by @djdameln in https://github.com/openvinotoolkit/anomalib/pull/2018
+- 🚀 Add compression and quantization for OpenVINO export by @adrianboguszewski in https://github.com/openvinotoolkit/anomalib/pull/2052

### Changed

2 changes: 1 addition & 1 deletion README.md
@@ -38,7 +38,7 @@ Anomalib is a deep learning library that aims to collect state-of-the-art anomal
- Simple and modular API and CLI for training, inference, benchmarking, and hyperparameter optimization.
- The largest public collection of ready-to-use deep learning anomaly detection algorithms and benchmark datasets.
- [**Lightning**](https://www.lightning.ai/) based model implementations to reduce boilerplate code and limit the implementation efforts to the bare essentials.
-- All models can be exported to [**OpenVINO**](https://www.intel.com/content/www/us/en/developer/tools/openvino-toolkit/overview.html) Intermediate Representation (IR) for accelerated inference on intel hardware.
+- The majority of models can be exported to [**OpenVINO**](https://www.intel.com/content/www/us/en/developer/tools/openvino-toolkit/overview.html) Intermediate Representation (IR) for accelerated inference on Intel hardware.
- A set of [inference tools](tools) for quick and easy deployment of the standard or custom anomaly detection models.

# 📦 Installation
2 changes: 1 addition & 1 deletion docs/source/index.md
@@ -8,7 +8,7 @@ Anomalib is a deep learning library that aims to collect state-of-the-art anomal
- Simple and modular API and CLI for training, inference, benchmarking, and hyperparameter optimization.
- The largest public collection of ready-to-use deep learning anomaly detection algorithms and benchmark datasets.
- Lightning based model implementations to reduce boilerplate code and limit the implementation efforts to the bare essentials.
-- All models can be exported to OpenVINO Intermediate Representation (IR) for accelerated inference on intel hardware.
+- The majority of models can be exported to OpenVINO Intermediate Representation (IR) for accelerated inference on Intel hardware.
- A set of inference tools for quick and easy deployment of the standard or custom anomaly detection models.
:::

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -39,7 +39,7 @@ core = [
"torchmetrics>=1.3.2",
"open-clip-torch>=2.23.0",
]
openvino = ["openvino-dev>=2023.1", "nncf>=2.6.0", "onnx>=1.16.0"]
openvino = ["openvino>=2024.0", "nncf>=2.10.0", "onnx>=1.16.0"]
loggers = [
"comet-ml>=3.31.7",
"gradio>=4",
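For reference, these optional dependencies are installed through the package extra, so picking up the new OpenVINO and NNCF pins would look like `pip install "anomalib[openvino]"` (assuming the extra keeps the `openvino` name shown above).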
7 changes: 6 additions & 1 deletion src/anomalib/cli/cli.py
@@ -229,10 +229,15 @@ def add_export_arguments(self, parser: ArgumentParser) -> None:
             fail_untyped=False,
             required=True,
         )
+        parser.add_argument(
+            "--data",
+            type=AnomalibDataModule,
+            required=False,
+        )
         added = parser.add_method_arguments(
             Engine,
             "export",
-            skip={"ov_args", "model"},
+            skip={"ov_args", "model", "datamodule"},
         )
         self.subcommand_method_arguments["export"] = added
         add_openvino_export_arguments(parser)
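With the new `--data` argument in place, a hypothetical export invocation using INT8 post-training quantization could look like the sketch below; the flag spellings mirror the docstring examples in `engine.py` further down, and the datamodule shorthand is an assumption, not verified CLI output.

```
anomalib export --model Padim --export_mode openvino --ckpt_path <PATH_TO_CHECKPOINT> \
    --compression_type int8_ptq --data MVTec
```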
4 changes: 2 additions & 2 deletions src/anomalib/deploy/__init__.py
@@ -3,7 +3,7 @@
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

-from .export import ExportType
+from .export import CompressionType, ExportType
from .inferencers import Inferencer, OpenVINOInferencer, TorchInferencer

-__all__ = ["Inferencer", "OpenVINOInferencer", "TorchInferencer", "ExportType"]
+__all__ = ["Inferencer", "OpenVINOInferencer", "TorchInferencer", "ExportType", "CompressionType"]
26 changes: 26 additions & 0 deletions src/anomalib/deploy/export.py
@@ -33,6 +33,32 @@ class ExportType(str, Enum):
     TORCH = "torch"


+class CompressionType(str, Enum):
+    """Model compression type when exporting to OpenVINO.
+
+    Examples:
+        >>> from anomalib.deploy import CompressionType
+        >>> CompressionType.INT8_PTQ
+        'int8_ptq'
+    """
+
+    FP16 = "fp16"
+    """
+    Weight compression (FP16)
+    All weights are converted to FP16.
+    """
+    INT8 = "int8"
+    """
+    Weight compression (INT8)
+    All weights are quantized to INT8, but are dequantized to floating point before inference.
+    """
+    INT8_PTQ = "int8_ptq"
+    """
+    Full integer post-training quantization (INT8)
+    All weights and operations are quantized to INT8. Inference is done in INT8 precision.
+    """


 class InferenceModel(nn.Module):
     """Inference model for export.
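To make the new enum concrete, here is a brief sketch (not part of the commit) of what each option implies, based on the `export_mixin.py` changes below:

```python
from anomalib.deploy import CompressionType

# FP16 stores weights in half precision; INT8 compresses weights only;
# INT8_PTQ also quantizes activations and therefore needs calibration data.
for ctype in CompressionType:
    needs_calibration = ctype is CompressionType.INT8_PTQ
    print(f"{ctype.value}: needs calibration datamodule = {needs_calibration}")
```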
4 changes: 2 additions & 2 deletions src/anomalib/deploy/inferencers/openvino_inferencer.py
@@ -23,10 +23,10 @@
logger = logging.getLogger("anomalib")

 if find_spec("openvino") is not None:
-    import openvino.runtime as ov
+    import openvino as ov

     if TYPE_CHECKING:
-        from openvino.runtime import CompiledModel
+        from openvino import CompiledModel
 else:
     logger.warning("OpenVINO is not installed. Please install OpenVINO to use OpenVINOInferencer.")

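For context, a minimal inference sketch against an exported (possibly quantized) IR; the results path and sample image below are illustrative assumptions:

```python
import numpy as np
from PIL import Image

from anomalib.deploy import OpenVINOInferencer

# Hypothetical location of a previously exported model.
inferencer = OpenVINOInferencer(path="results/Padim/MVTec/latest/weights/openvino/model.xml")
image = np.array(Image.open("sample.png"))
prediction = inferencer.predict(image)
print(prediction.pred_score, prediction.pred_label)
```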
13 changes: 11 additions & 2 deletions src/anomalib/engine/engine.py
@@ -26,7 +26,7 @@
from anomalib.callbacks.timer import TimerCallback
from anomalib.callbacks.visualizer import _VisualizationCallback
from anomalib.data import AnomalibDataModule, AnomalibDataset, PredictDataset
-from anomalib.deploy import ExportType
+from anomalib.deploy import CompressionType, ExportType
from anomalib.models import AnomalyModule
from anomalib.utils.normalization import NormalizationMethod
from anomalib.utils.path import create_versioned_dir
@@ -869,6 +869,8 @@ def export(
         export_root: str | Path | None = None,
         input_size: tuple[int, int] | None = None,
         transform: Transform | None = None,
+        compression_type: CompressionType | None = None,
+        datamodule: AnomalibDataModule | None = None,
         ov_args: dict[str, Any] | None = None,
         ckpt_path: str | Path | None = None,
     ) -> Path | None:
@@ -884,6 +886,11 @@ def export(
             transform (Transform | None, optional): Input transform to include in the exported model. If not provided,
                 the engine will try to use the default transform from the model.
                 Defaults to ``None``.
+            compression_type (CompressionType | None, optional): Compression type for OpenVINO exporting only.
+                Defaults to ``None``.
+            datamodule (AnomalibDataModule | None, optional): Lightning datamodule.
+                Must be provided if CompressionType.INT8_PTQ is selected.
+                Defaults to ``None``.
             ov_args (dict[str, Any] | None, optional): This is optional and used only for OpenVINO's model optimizer.
                 Defaults to None.
             ckpt_path (str | Path | None): Checkpoint path. If provided, the model will be loaded from this path.
@@ -910,7 +917,7 @@
                 anomalib export --model Padim --export_mode openvino --ckpt_path <PATH_TO_CHECKPOINT> \
                 --input_size "[256,256]"
                 ```
-            4. You can also overrride OpenVINO model optimizer by adding the ``--ov_args.<key>`` arguments.
+            4. You can also override OpenVINO model optimizer by adding the ``--ov_args.<key>`` arguments.
                 ```python
                 anomalib export --model Padim --export_mode openvino --ckpt_path <PATH_TO_CHECKPOINT> \
                 --input_size "[256,256]" --ov_args.compress_to_fp16 False
@@ -945,6 +952,8 @@
                 input_size=input_size,
                 transform=transform,
                 task=self.task,
+                compression_type=compression_type,
+                datamodule=datamodule,
                 ov_args=ov_args,
             )
         else:
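Putting the new arguments together from the Python API, an end-to-end sketch (model and datamodule choices are placeholders, not prescribed by this commit):

```python
from anomalib.data import MVTec
from anomalib.deploy import CompressionType, ExportType
from anomalib.engine import Engine
from anomalib.models import Padim

model = Padim()
datamodule = MVTec()
engine = Engine()
engine.fit(model=model, datamodule=datamodule)

# INT8_PTQ requires the datamodule so NNCF can draw calibration batches.
engine.export(
    model=model,
    export_type=ExportType.OPENVINO,
    compression_type=CompressionType.INT8_PTQ,
    datamodule=datamodule,
)
```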
35 changes: 31 additions & 4 deletions src/anomalib/models/components/base/export_mixin.py
@@ -17,7 +17,8 @@
from torchvision.transforms.v2 import Transform

from anomalib import TaskType
-from anomalib.deploy.export import ExportType, InferenceModel
+from anomalib.data import AnomalibDataModule
+from anomalib.deploy.export import CompressionType, ExportType, InferenceModel
from anomalib.utils.exceptions import try_import

if TYPE_CHECKING:
@@ -156,6 +157,8 @@ def to_openvino(
         export_root: Path | str,
         input_size: tuple[int, int] | None = None,
         transform: Transform | None = None,
+        compression_type: CompressionType | None = None,
+        datamodule: AnomalibDataModule | None = None,
         ov_args: dict[str, Any] | None = None,
         task: TaskType | None = None,
     ) -> Path:
@@ -168,7 +171,12 @@
             transform (Transform, optional): Input transforms used for the model. If not provided, the transform is
                 taken from the model.
                 Defaults to ``None``.
-            ov_args: Model optimizer arguments for OpenVINO model conversion.
+            compression_type (CompressionType, optional): Compression type for better inference performance.
+                Defaults to ``None``.
+            datamodule (AnomalibDataModule | None, optional): Lightning datamodule.
+                Must be provided if CompressionType.INT8_PTQ is selected.
+                Defaults to ``None``.
+            ov_args (dict | None): Model optimizer arguments for OpenVINO model conversion.
                 Defaults to ``None``.
             task (TaskType | None): Task type.
                 Defaults to ``None``.
@@ -213,18 +221,37 @@
         if not try_import("openvino"):
             logger.exception("Could not find OpenVINO. Please check OpenVINO installation.")
             raise ModuleNotFoundError
+        if not try_import("nncf"):
+            logger.exception("Could not find NNCF. Please check NNCF installation.")
+            raise ModuleNotFoundError

+        import nncf
         import openvino as ov

         with TemporaryDirectory() as onnx_directory:
             model_path = self.to_onnx(onnx_directory, input_size, transform, task)
             export_root = _create_export_root(export_root, ExportType.OPENVINO)
             ov_model_path = export_root / "model.xml"
             ov_args = {} if ov_args is None else ov_args
-            # fp16 compression is enabled by default
-            compress_to_fp16 = ov_args.get("compress_to_fp16", True)

             model = ov.convert_model(model_path, **ov_args)
+            if compression_type == CompressionType.INT8:
+                model = nncf.compress_weights(model)
+            elif compression_type == CompressionType.INT8_PTQ:
+                if datamodule is None:
+                    msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression"
+                    raise ValueError(msg)
+
+                dataloader = datamodule.val_dataloader()
+                if len(dataloader.dataset) < 300:
+                    logger.warning(
+                        f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images",
+                    )
+                calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"])
+                model = nncf.quantize(model, calibration_dataset)

+            # fp16 compression is enabled by default
+            compress_to_fp16 = compression_type == CompressionType.FP16
             ov.save_model(model, ov_model_path, compress_to_fp16=compress_to_fp16)
             _write_metadata_to_json(self._get_metadata(task), export_root)
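For readers unfamiliar with NNCF, the PTQ branch above reduces to the following standalone pattern; the IR path and the random stand-in batches are assumptions for illustration only:

```python
import nncf
import numpy as np
import openvino as ov

model = ov.Core().read_model("model.xml")  # hypothetical exported IR

# Stand-in for datamodule.val_dataloader(): an iterable of anomalib-style batches.
dataloader = [{"image": np.random.rand(1, 3, 256, 256).astype(np.float32)} for _ in range(300)]

# The transform function tells NNCF how to pull the model input out of a batch.
calibration_dataset = nncf.Dataset(dataloader, lambda batch: batch["image"])
quantized = nncf.quantize(model, calibration_dataset)
ov.save_model(quantized, "model_int8.xml", compress_to_fp16=False)
```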
