Remove config from torch inferencer #1001

Merged

23 commits
501484c
Fix metadata path
samet-akcay Mar 10, 2023
15ad7e9
Merge branch 'main' of github.com:openvinotoolkit/anomalib
samet-akcay Mar 20, 2023
a6bc69f
Merge branch 'main' of github.com:openvinotoolkit/anomalib
samet-akcay Mar 22, 2023
7b53b78
Merge branch 'main' of github.com:openvinotoolkit/anomalib
samet-akcay Mar 24, 2023
af2022f
Merge branch 'main' of github.com:openvinotoolkit/anomalib
samet-akcay Mar 28, 2023
e4aadca
Merge branch 'main' of github.com:openvinotoolkit/anomalib
samet-akcay Mar 31, 2023
c61472e
Merge branch 'main' of github.com:openvinotoolkit/anomalib
samet-akcay Apr 4, 2023
36fa3b3
Add torch to config options
samet-akcay Apr 5, 2023
4bb0666
Add torch to config options
samet-akcay Apr 5, 2023
d8b1b57
Save the lightning model in weights/lightning/model.ckpt
samet-akcay Apr 5, 2023
6df8d04
Add export_to_torch support
samet-akcay Apr 5, 2023
067e697
removed config from torch inferencer
samet-akcay Apr 5, 2023
cfdd5df
Merge branch 'main' into remove-config-from-torch-inferencer
samet-akcay Apr 5, 2023
7d94cec
addressed pr comments
samet-akcay Apr 11, 2023
474c3b3
Merge branch 'remove-config-from-torch-inferencer' of github.com:same…
samet-akcay Apr 11, 2023
89a6a6a
Merge branch 'main' into remove-config-from-torch-inferencer
samet-akcay Apr 11, 2023
e2396fb
Remove anomalymodule from torch inference docstring
samet-akcay Apr 11, 2023
b872a47
Modify benchmark.py
samet-akcay Apr 11, 2023
c1f5063
Modify sweep inference helpers
samet-akcay Apr 11, 2023
aa4c8b2
Fix tests
samet-akcay Apr 12, 2023
288f4b4
Merge branch 'remove-config-from-torch-inferencer' of github.com:same…
samet-akcay Apr 12, 2023
f03a366
Fix export tests
samet-akcay Apr 13, 2023
4c6bebb
Fix notebooks
samet-akcay Apr 13, 2023
2 changes: 0 additions & 2 deletions README.md
@@ -192,7 +192,6 @@ Example OpenVINO Inference:
 
 ```bash
 python tools/inference/openvino_inference.py \
-    --config src/anomalib/models/padim/config.yaml \
     --weights results/padim/mvtec/bottle/run/openvino/model.bin \
     --metadata results/padim/mvtec/bottle/run/openvino/metadata.json \
     --input datasets/MVTec/bottle/test/broken_large/000.png \
@@ -207,7 +206,6 @@ A quick example:
 
 ```bash
 python tools/inference/gradio_inference.py \
-    --config src/anomalib/models/padim/config.yaml \
     --weights results/padim/mvtec/bottle/run/weights/model.ckpt
 ```
 
4 changes: 2 additions & 2 deletions notebooks/000_getting_started/001_getting_started.ipynb
@@ -433,8 +433,8 @@
    }
   ],
   "source": [
-   "openvino_model_path = output_path / \"openvino\" / \"model.bin\"\n",
-   "metadata_path = output_path / \"openvino\" / \"metadata.json\"\n",
+   "openvino_model_path = output_path / \"weights\" / \"openvino\" / \"model.bin\"\n",
+   "metadata_path = output_path / \"weights\" / \"openvino\" / \"metadata.json\"\n",
    "print(openvino_model_path.exists(), metadata_path.exists())"
   ]
  },
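The notebook update reflects the new export layout introduced in this PR: exported artifacts now land under `weights/<export_mode>/` inside the run directory. A minimal sketch of the same check outside the notebook, assuming a run directory like the ones in the README examples:

```python
from pathlib import Path

# Hypothetical run directory; matches the layout used in the README examples.
output_path = Path("results/padim/mvtec/bottle/run")

openvino_model_path = output_path / "weights" / "openvino" / "model.bin"
metadata_path = output_path / "weights" / "openvino" / "metadata.json"
print(openvino_model_path.exists(), metadata_path.exists())
```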
4 changes: 2 additions & 2 deletions notebooks/400_openvino/401_nncf.ipynb
@@ -152,7 +152,7 @@
    "\n",
    "```yaml\n",
    "optimization:\n",
-   "  export_mode: null #options: onnx, openvino\n",
+   "  export_mode: null # options: torch, onnx, openvino\n",
    "  nncf:\n",
    "    apply: true\n",
    "    input_info:\n",
@@ -282,7 +282,7 @@
    "\n",
    "```yaml\n",
    "optimization:\n",
-   "  export_mode: null #options: onnx, openvino\n",
+   "  export_mode: null # options: torch, onnx, openvino\n",
    "  nncf:\n",
    "    apply: true\n",
    "    input_info:\n",
50 changes: 38 additions & 12 deletions src/anomalib/deploy/export.py
@@ -25,6 +25,7 @@ class ExportMode(str, Enum):
 
     ONNX = "onnx"
     OPENVINO = "openvino"
+    TORCH = "torch"
 
 
 def get_model_metadata(model: AnomalyModule) -> dict[str, Tensor]:
@@ -59,6 +60,7 @@ def get_metadata(task: TaskType, transform: dict[str, Any], model: AnomalyModule
         task (TaskType): Task type.
         transform (dict[str, Any]): Transform used for the model.
         model (AnomalyModule): Model to export.
+        export_mode (ExportMode): Mode to export the model. Torch, ONNX or OpenVINO.
 
     Returns:
         dict[str, Any]: Metadata for the exported model.
@@ -90,20 +92,44 @@ def export(
         transform (dict[str, Any]): Data transforms (augmentatiions) used for the model.
         input_size (tuple[int, int]): Input size of the model.
         model (AnomalyModule): Anomaly model to export.
-        export_mode (ExportMode): Mode to export the model. ONNX or OpenVINO.
-        export_root (str | Path): Path to exported ONNX/OpenVINO IR.
+        export_mode (ExportMode): Mode to export the model. Torch, ONNX or OpenVINO.
+        export_root (str | Path): Path to exported Torch, ONNX or OpenVINO IR.
     """
-    # Write metadata to json file. The file is written in the same directory as the target model.
-    export_path = Path(export_root) / export_mode.value
+    # Create export directory.
+    export_path = Path(export_root) / "weights" / export_mode.value
     export_path.mkdir(parents=True, exist_ok=True)
-    with (Path(export_path) / "metadata.json").open("w", encoding="utf-8") as metadata_file:
-        metadata = get_metadata(task, transform, model)
-        json.dump(metadata, metadata_file, ensure_ascii=False, indent=4)
-
-    # Export model to onnx and convert to OpenVINO IR if export mode is set to OpenVINO.
-    onnx_path = export_to_onnx(model, input_size, export_path)
-    if export_mode == ExportMode.OPENVINO:
-        export_to_openvino(export_path, onnx_path)
+
+    # Get metadata.
+    metadata = get_metadata(task, transform, model)
+
+    if export_mode == ExportMode.TORCH:
+        export_to_torch(model, metadata, export_path)
+
+    elif export_mode in (ExportMode.ONNX, ExportMode.OPENVINO):
+        # Write metadata to json file. The file is written in the same directory as the target model.
+        with (Path(export_path) / "metadata.json").open("w", encoding="utf-8") as metadata_file:
+            json.dump(metadata, metadata_file, ensure_ascii=False, indent=4)
+
+        # Export model to onnx and convert to OpenVINO IR if export mode is set to OpenVINO.
+        onnx_path = export_to_onnx(model, input_size, export_path)
+        if export_mode == ExportMode.OPENVINO:
+            export_to_openvino(export_path, onnx_path)
+
+    else:
+        raise ValueError(f"Unknown export mode {export_mode}")
+
+
+def export_to_torch(model: AnomalyModule, metadata: dict[str, Any], export_path: Path) -> None:
+    """Export AnomalibModel to torch.
+
+    Args:
+        model (AnomalyModule): Model to export.
+        export_path (Path): Path to the folder storing the exported model.
+    """
+    torch.save(
+        obj={"model": model.model, "metadata": metadata},
+        f=export_path / "model.pt",
+    )
 
 
 def export_to_onnx(model: AnomalyModule, input_size: tuple[int, int], export_path: Path) -> Path:
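The new `ExportMode.TORCH` branch saves the raw `nn.Module` together with its metadata in a single `model.pt`, so consumers no longer need the training config. A minimal sketch of loading that bundle directly, assuming a run directory like the ones in the README examples:

```python
from pathlib import Path

import torch

# Hypothetical run directory; the weights/torch subfolder is created by export().
export_root = Path("results/padim/mvtec/bottle/run")
payload = torch.load(export_root / "weights" / "torch" / "model.pt", map_location="cpu")

model = payload["model"]        # raw nn.Module saved by export_to_torch()
metadata = payload["metadata"]  # task, transform dict, normalization stats, etc.

model.eval()
print(type(model).__name__, sorted(metadata))
```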
75 changes: 18 additions & 57 deletions src/anomalib/deploy/inferencers/torch_inferencer.py
@@ -8,19 +8,15 @@
 from pathlib import Path
 from typing import Any
 
+import albumentations as A
 import cv2
 import numpy as np
 import torch
-from omegaconf import DictConfig, ListConfig
-from torch import Tensor
+from omegaconf import DictConfig
+from torch import Tensor, nn
 
-from anomalib.config import get_configurable_parameters
 from anomalib.data import TaskType
-from anomalib.data.utils import InputNormalizationMethod, get_transforms
 from anomalib.data.utils.boxes import masks_to_boxes
-from anomalib.deploy.export import get_model_metadata
-from anomalib.models import get_model
-from anomalib.models.components import AnomalyModule
 
 from .base_inferencer import Inferencer
 
@@ -29,38 +25,21 @@ class TorchInferencer(Inferencer):
     """PyTorch implementation for the inference.
 
     Args:
-        config (str | Path | DictConfig | ListConfig): Configurable parameters that are used
-            during the training stage.
-        model_source (str | Path | AnomalyModule): Path to the model ckpt file or the Anomaly model.
-        metadata_path (str | Path, optional): Path to metadata file. If none, it tries to load the params
-            from the model state_dict. Defaults to None.
-        device (str | None, optional): Device to use for inference. Options are auto, cpu, cuda. Defaults to "auto".
+        path (str | Path): Path to Torch model weights.
+        device (str): Device to use for inference. Options are auto, cpu, cuda. Defaults to "auto".
     """
 
     def __init__(
         self,
-        config: str | Path | DictConfig | ListConfig,
-        model_source: str | Path | AnomalyModule,
-        metadata_path: str | Path | None = None,
+        path: str | Path,
         device: str = "auto",
     ) -> None:
         self.device = self._get_device(device)
 
-        # Check and load the configuration
-        if isinstance(config, (str, Path)):
-            self.config = get_configurable_parameters(config_path=config)
-        elif isinstance(config, (DictConfig, ListConfig)):
-            self.config = config
-        else:
-            raise ValueError(f"Unknown config type {type(config)}")
-
-        # Check and load the model weights.
-        if isinstance(model_source, AnomalyModule):
-            self.model = model_source
-        else:
-            self.model = self.load_model(model_source)
-
-        self.metadata = self._load_metadata(metadata_path)
+        # Load the model weights.
+        self.model = self.load_model(path)
+        self.metadata = self._load_metadata(path)
+        self.transform = A.from_dict(self.metadata["transform"])
 
     @staticmethod
     def _get_device(device: str) -> torch.device:
@@ -82,24 +61,18 @@ def _get_device(device: str) -> torch.device:
         return torch.device(device)
 
     def _load_metadata(self, path: str | Path | None = None) -> dict | DictConfig:
-        """Load metadata from file or from model state dict.
+        """Load metadata from file.
 
         Args:
-            path (str | Path | None, optional): Path to metadata file. If none, it tries to load the params
-                from the model state_dict. Defaults to None.
+            path (str | Path): Path to the model pt file.
 
         Returns:
             dict: Dictionary containing the metadata.
         """
-        metadata: dict[str, float | np.ndarray | Tensor] | DictConfig
-        if path is None:
-            # Torch inferencer still reads metadata from the model.
-            metadata = get_model_metadata(self.model)
-        else:
-            metadata = super()._load_metadata(path)
+        metadata = torch.load(path, map_location=self.device)["metadata"] if path else {}
         return metadata
 
-    def load_model(self, path: str | Path) -> AnomalyModule:
+    def load_model(self, path: str | Path) -> nn.Module:
         """Load the PyTorch model.
 
         Args:
@@ -108,8 +81,8 @@ def load_model(self, path: str | Path) -> AnomalyModule:
         Returns:
             (AnomalyModule): PyTorch Lightning model.
         """
-        model = get_model(self.config)
-        model.load_state_dict(torch.load(path, map_location=self.device)["state_dict"])
+        model = torch.load(path, map_location=self.device)["model"]
         model.eval()
         return model.to(self.device)
 
@@ -122,19 +95,7 @@ def pre_process(self, image: np.ndarray) -> Tensor:
         Returns:
             Tensor: pre-processed image.
         """
-        transform_config = (
-            self.config.dataset.transform_config.eval if "transform_config" in self.config.dataset.keys() else None
-        )
-
-        image_size = (self.config.dataset.image_size[0], self.config.dataset.image_size[1])
-        center_crop = self.config.dataset.get("center_crop")
-        if center_crop is not None:
-            center_crop = tuple(center_crop)
-        normalization = InputNormalizationMethod(self.config.dataset.normalization)
-        transform = get_transforms(
-            config=transform_config, image_size=image_size, center_crop=center_crop, normalization=normalization
-        )
-        processed_image = transform(image=image)["image"]
+        processed_image = self.transform(image=image)["image"]
 
         if len(processed_image) == 3:
             processed_image = processed_image.unsqueeze(0)
@@ -209,7 +170,7 @@ def post_process(self, predictions: Tensor, metadata: dict | DictConfig | None =
         if pred_mask is not None:
             pred_mask = cv2.resize(pred_mask, (image_width, image_height))
 
-            if self.config.dataset.task == TaskType.DETECTION:
+            if self.metadata["task"] == TaskType.DETECTION:
                 pred_boxes = masks_to_boxes(torch.from_numpy(pred_mask))[0][0].numpy()
                 box_labels = np.ones(pred_boxes.shape[0])
             else:
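With the config argument gone, the inferencer is constructed from the exported `model.pt` alone; transforms and task type come from the bundled metadata. A usage sketch under those assumptions (the import path and the inherited `predict()` helper follow anomalib's deploy package; the file paths are illustrative):

```python
import cv2

from anomalib.deploy import TorchInferencer

# No config.yaml anymore: point the inferencer straight at the exported bundle.
inferencer = TorchInferencer(
    path="results/padim/mvtec/bottle/run/weights/torch/model.pt",  # hypothetical run path
    device="auto",
)

# predict() is inherited from the Inferencer base class and chains
# pre_process -> forward -> post_process using the bundled metadata.
image = cv2.cvtColor(cv2.imread("datasets/MVTec/bottle/test/broken_large/000.png"), cv2.COLOR_BGR2RGB)
predictions = inferencer.predict(image)
print(predictions.pred_score, predictions.pred_label)
```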
2 changes: 1 addition & 1 deletion src/anomalib/models/cfa/config.yaml
@@ -62,7 +62,7 @@ logging:
   log_graph: false # Logs the model graph to respective logger.
 
 optimization:
-  export_mode: null #options: onnx, openvino
+  export_mode: null # options: torch, onnx, openvino
 
 # PL Trainer Args. Don't add extra parameter here.
 trainer:
2 changes: 1 addition & 1 deletion src/anomalib/models/cflow/config.yaml
@@ -68,7 +68,7 @@ logging:
   log_graph: false # Logs the model graph to respective logger.
 
 optimization:
-  export_mode: null #options: onnx, openvino
+  export_mode: null # options: torch, onnx, openvino
 
 # PL Trainer Args. Don't add extra parameter here.
 trainer:
3 changes: 2 additions & 1 deletion src/anomalib/models/cflow/utils.py
@@ -11,9 +11,10 @@
 import numpy as np
 import torch
 from FrEIA.framework import SequenceINN
-from FrEIA.modules import AllInOneBlock
 from torch import Tensor, nn
 
+from anomalib.models.components.flow import AllInOneBlock
+
 logger = logging.getLogger(__name__)
 
 
8 changes: 8 additions & 0 deletions src/anomalib/models/components/flow/__init__.py
@@ -0,0 +1,8 @@
+"""All In One Block Layer."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .all_in_one_block import AllInOneBlock
+
+__all__ = ["AllInOneBlock"]
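Vendoring `AllInOneBlock` under `anomalib.models.components.flow` keeps the FrEIA call pattern intact for `cflow/utils.py`. A rough sketch of how such a block is typically appended to a `SequenceINN`; the subnet width, depth, and feature dimension below are illustrative, not cflow's actual values:

```python
import torch
from FrEIA.framework import SequenceINN
from torch import nn

from anomalib.models.components.flow import AllInOneBlock


def subnet_fc(dims_in: int, dims_out: int) -> nn.Module:
    """Small fully connected subnet driving the coupling transform."""
    return nn.Sequential(nn.Linear(dims_in, 128), nn.ReLU(), nn.Linear(128, dims_out))


flow = SequenceINN(16)  # 16-dim feature vectors, chosen arbitrarily for the sketch
for _ in range(4):
    flow.append(AllInOneBlock, subnet_constructor=subnet_fc, affine_clamping=1.9)

z, log_jac_det = flow(torch.randn(8, 16))  # forward pass returns latents and log |det J|
print(z.shape, log_jac_det.shape)
```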