Skip to content

Commit

Permalink
Support static shapes when exporting to ONNX or OpenVINO
Browse files Browse the repository at this point in the history
Signed-off-by: Adrian Boguszewski <adrian.boguszewski@intel.com>
  • Loading branch information
adrianboguszewski committed Apr 17, 2024
1 parent 21c765c commit 69d0c5b
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added

- 🚀 Update OpenVINO and ONNX export to support fixed input shape by @adrianboguszewski in https://github.com/openvinotoolkit/anomalib/pull/2006

### Changed

- 🔨Rename OptimalF1 to F1Max for consistency with the literature, by @samet-akcay in https://github.com/openvinotoolkit/anomalib/pull/1980
Expand Down
20 changes: 15 additions & 5 deletions src/anomalib/deploy/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def export_to_torch(
def export_to_onnx(
model: AnomalyModule,
export_root: Path | str,
input_size: tuple[int, int] | None = None,
transform: Transform | None = None,
task: TaskType | None = None,
export_type: ExportType = ExportType.ONNX,
Expand All @@ -169,6 +170,8 @@ def export_to_onnx(
Args:
model (AnomalyModule): Model to export.
export_root (Path): Path to the root folder of the exported model.
input_size (tuple[int, int] | None, optional): Image size used as the input for onnx converter.
Defaults to None.
transform (Transform, optional): Input transforms used for the model. If not provided, the transform is taken
from the model.
Defaults to ``None``.
Expand Down Expand Up @@ -212,14 +215,18 @@ def export_to_onnx(
transform = transform or model.transform or model.configure_transforms()
inference_model = InferenceModel(model=model.model, transform=transform, disable_antialias=True)
export_root = _create_export_root(export_root, export_type)
input_shape = torch.zeros((1, 3, *input_size)) if input_size else torch.zeros((1, 3, 1, 1))
dynamic_axes = (
None if input_size else {"input": {0: "batch_size", 2: "height", 3: "weight"}, "output": {0: "batch_size"}}
)
_write_metadata_to_json(export_root, model, task)
onnx_path = export_root / "model.onnx"
torch.onnx.export(
inference_model,
torch.zeros((1, 3, 1, 1)).to(model.device),
input_shape.to(model.device),
str(onnx_path),
opset_version=14,
dynamic_axes={"input": {0: "batch_size", 2: "height", 3: "weight"}, "output": {0: "batch_size"}},
dynamic_axes=dynamic_axes,
input_names=["input"],
output_names=["output"],
)
Expand All @@ -228,17 +235,20 @@ def export_to_onnx(


def export_to_openvino(
export_root: Path | str,
model: AnomalyModule,
export_root: Path | str,
input_size: tuple[int, int] | None = None,
transform: Transform | None = None,
ov_args: dict[str, Any] | None = None,
task: TaskType | None = None,
) -> Path:
"""Convert onnx model to OpenVINO IR.
Args:
export_root (Path): Path to the export folder.
model (AnomalyModule): AnomalyModule to export.
export_root (Path): Path to the export folder.
input_size (tuple[int, int] | None, optional): Input size of the model. Used for adding metadata to the IR.
Defaults to None.
transform (Transform, optional): Input transforms used for the model. If not provided, the transform is taken
from the model.
Defaults to ``None``.
Expand Down Expand Up @@ -289,7 +299,7 @@ def export_to_openvino(
... )
"""
model_path = export_to_onnx(model, export_root, transform, task, ExportType.OPENVINO)
model_path = export_to_onnx(model, export_root, input_size, transform, task, ExportType.OPENVINO)
ov_model_path = model_path.with_suffix(".xml")
ov_args = {} if ov_args is None else ov_args
if convert_model is not None and serialize is not None:
Expand Down
4 changes: 4 additions & 0 deletions src/anomalib/deploy/inferencers/openvino_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ def predict(
msg = f"Input image must be a numpy array or a path to an image. Got {type(image)}"
raise TypeError(msg)

# Resize image to model input size if not dynamic
if self.input_blob.partial_shape[2:].is_static:
image = cv2.resize(image, tuple(self.input_blob.shape[2:][::-1]))

# Normalize numpy array to range [0, 1]
if image.dtype != np.float32:
image = image.astype(np.float32)
Expand Down
9 changes: 7 additions & 2 deletions src/anomalib/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,7 @@ def export(
model: AnomalyModule,
export_type: ExportType,
export_root: str | Path | None = None,
input_size: tuple[int, int] | None = None,
transform: Transform | None = None,
ov_args: dict[str, Any] | None = None,
ckpt_path: str | Path | None = None,
Expand All @@ -851,6 +852,8 @@ def export(
export_type (ExportType): Export type.
export_root (str | Path | None, optional): Path to the output directory. If it is not set, the model is
exported to trainer.default_root_dir. Defaults to None.
input_size (tuple[int, int] | None, optional): A statis input shape for the model, which is exported to ONNX
and OpenVINO format. Defaults to None.
transform (Transform | None, optional): Input transform to include in the exported model. If not provided,
the engine will try to use the transform from the datamodule or dataset. Defaults to None.
ov_args (dict[str, Any] | None, optional): This is optional and used only for OpenVINO's model optimizer.
Expand All @@ -877,10 +880,10 @@ def export(
```python
anomalib export --model Padim --export_mode OPENVINO --data Visa --input_size "[256,256]"
```
4. You can also overrride OpenVINO model optimizer by adding the ``--mo_args.<key>`` arguments.
4. You can also overrride OpenVINO model optimizer by adding the ``--ov_args.<key>`` arguments.
```python
anomalib export --model Padim --export_mode OPENVINO --data Visa --input_size "[256,256]" \
--mo_args.compress_to_fp16 False
--ov_args.compress_to_fp16 False
```
"""
self._setup_trainer(model)
Expand All @@ -903,13 +906,15 @@ def export(
exported_model_path = export_to_onnx(
model=model,
export_root=export_root,
input_size=input_size,
transform=transform,
task=self.task,
)
elif export_type == ExportType.OPENVINO:
exported_model_path = export_to_openvino(
model=model,
export_root=export_root,
input_size=input_size,
transform=transform,
task=self.task,
ov_args=ov_args,
Expand Down

0 comments on commit 69d0c5b

Please sign in to comment.