Skip to content

Commit

Permalink
Refactor Visual Prompting OV model to use MAPI (#3638)
Browse files Browse the repository at this point in the history
* Use SAM OV models from MAPI

* Disable throughput for vpt because of sync infer

* Update vpt exporter

* Replace infer in ZSL with MAPI

* Replace infer in ZSL with MAPI

* Fix input name for ZSL wrapper

* Fix merge artifacts

* Cleanup

* Adapt ZSL optimize to MAPI

* Del unused code

* Refactor infer in zsl

* Update SAM and ZSL MAPI usage

* Adapt to changes in MAPI

* Fix import sorting

* Bump MAPI version

* Del outdated unit tests

* Fix tiling utest

* Fix vpt onnx export

* Fix visual prompting unit tests

* Fix linters

* Fix setup script

* Add todo
  • Loading branch information
sovrasov authored Jul 3, 2024
1 parent 1fde0c8 commit 61b4db0
Show file tree
Hide file tree
Showing 8 changed files with 342 additions and 988 deletions.
5 changes: 1 addition & 4 deletions src/otx/algo/visual_prompting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# Copyright (C) 2023 Intel Corporation
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Module for OTX visual prompting models."""

from . import backbones, decoders, encoders
from .openvino_models import VisualPromptingDecoder, VisualPromptingImageEncoder
from .segment_anything import OTXSegmentAnything, SegmentAnything
from .zero_shot_segment_anything import OTXZeroShotSegmentAnything, ZeroShotSegmentAnything

Expand All @@ -16,6 +15,4 @@
"SegmentAnything",
"OTXZeroShotSegmentAnything",
"ZeroShotSegmentAnything",
"VisualPromptingImageEncoder",
"VisualPromptingDecoder",
]
150 changes: 0 additions & 150 deletions src/otx/algo/visual_prompting/openvino_models.py

This file was deleted.

25 changes: 19 additions & 6 deletions src/otx/core/exporter/visual_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import logging as log
import tempfile
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

Expand Down Expand Up @@ -61,22 +62,23 @@ def export( # type: ignore[override]
log.warning(msg)
fn = self.to_openvino
elif export_format == OTXExportFormatType.ONNX:
fn = self.to_onnx
fn = partial(self.to_onnx, embed_metadata=True)
else:
msg = f"Unsupported export format: {export_format}"
raise ValueError(msg)

return { # type: ignore[return-value]
module: fn(models[module], output_dir, f"{base_model_name}_{module}", precision)
module: fn(models[module], output_dir, f"{base_model_name}_{module}", precision, model_type=f"sam_{module}")
for module in ["image_encoder", "decoder"]
}

def to_openvino(
self,
model: OTXModel,
model: OTXModel | torch.nn.Module,
output_dir: Path,
base_model_name: str = "exported_model",
precision: OTXPrecisionType = OTXPrecisionType.FP32,
model_type: str = "sam",
) -> Path:
"""Export to OpenVINO Intermediate Representation format.
Expand All @@ -93,12 +95,17 @@ def to_openvino(
tmp_dir,
base_model_name,
OTXPrecisionType.FP32,
False,
embed_metadata=False,
)
exported_model = openvino.convert_model(tmp_dir / (base_model_name + ".onnx"))

exported_model = self._postprocess_openvino_model(exported_model)

if self.metadata is not None:
export_metadata = self._extend_model_metadata(self.metadata)
export_metadata[("model_info", "model_type")] = model_type
exported_model = self._embed_openvino_ir_metadata(exported_model, export_metadata)

save_path = output_dir / (base_model_name + ".xml")
openvino.save_model(exported_model, save_path, compress_to_fp16=(precision == OTXPrecisionType.FP16))
log.info("Converting to OpenVINO is done.")
Expand All @@ -107,11 +114,12 @@ def to_openvino(

def to_onnx(
self,
model: OTXModel,
model: OTXModel | torch.nn.Module,
output_dir: Path,
base_model_name: str = "exported_model",
precision: OTXPrecisionType = OTXPrecisionType.FP32,
embed_metadata: bool = True,
model_type: str = "sam",
) -> Path:
"""Export the given PyTorch model to ONNX format and save it to the specified output directory.
Expand All @@ -136,7 +144,12 @@ def to_onnx(
)

onnx_model = onnx.load(save_path)
onnx_model = self._postprocess_onnx_model(onnx_model, embed_metadata, precision)
onnx_model = self._postprocess_onnx_model(onnx_model, False, precision)

if self.metadata is not None and embed_metadata:
export_metadata = self._extend_model_metadata(self.metadata)
export_metadata[("model_info", "model_type")] = model_type
onnx_model = self._embed_onnx_metadata(onnx_model, export_metadata)

onnx.save(onnx_model, save_path)
log.info("Converting to ONNX is done.")
Expand Down
Loading

0 comments on commit 61b4db0

Please sign in to comment.