Commit

Add optimize for visual prompting to 2.0 (#3040)
* Fix the returned shape when all predicted masks are zero

* Swap order of prompts in segment_anything.py

* Enable exporting segment anything modules separately

* Update prompt order and refactoring

* Set `return_single_mask` to True

* Add OVIR inference logic

* Updates for integration test

* precommit

* Update prompt getter

* (WIP) Update zero-shot

* (WIP) Internalize preprocess into model

* Update reference infos to buffer

* Update reference_info path

* Enable standalone infer logic

* Update location of `-load_latest_reference_info` for OVModel

* Update recipes

* Enable `infer` ov inference

* Enable standalone `infer` logic on OVModel

* precommit

* Update unittest

* Update for integration test

* precommit

* Fix

* Fix

* Fix

* Fix intg tests

* Enable updating `export_args` in `deploy_cfg`

* Update with walrus

* Update to use dict labels

* Change openvino model names

* Update compatibility with zero-shot

* Refactoring for unnecessary assigned variables

* Avoid repeatedly executing `torch.cat`

* precommit

* Fix unit test

* Update variable name

* Add `example_inputs` in anomaly

* Fix unit test

* Fix

* Update `model_names` for visual prompting

* precommit

* Avoid including other params in `example_inputs`

* Disable condition for visual prompting

* Update to `example_inputs`

* Remove unused kwargs

* Update

* Remove unused parts

* Update exported models' names

* Remove `example_inputs`

* Add `OTXVisualPromptingModelExporter`

* Update overlapped region refinement

* Update `export_params`

* Enable optimize

* Add exportable code, but updating `demo.py` is required

* Update model_type

* Fix integration test
sungchul2 authored Mar 11, 2024
1 parent 196dc83 commit 506a22f
Showing 6 changed files with 967 additions and 32 deletions.
22 changes: 15 additions & 7 deletions src/otx/algo/visual_prompting/zero_shot_segment_anything.py
@@ -400,13 +400,12 @@ def _inspect_overlapping_areas(
used_points: dict[int, list[Tensor]],
threshold_iou: float = 0.8,
) -> None:
-def _calculate_mask_iou(mask1: Tensor, mask2: Tensor) -> Tensor:
-intersection = torch.logical_and(mask1, mask2).sum().item()
-union = torch.logical_or(mask1, mask2).sum().item()
-if union == 0:
-# Avoid division by zero
-return 0.0
-return intersection / union
+def _calculate_mask_iou(mask1: Tensor, mask2: Tensor) -> tuple[float, Tensor | None]:
+if (union := torch.logical_or(mask1, mask2).sum().item()) == 0:
+# Avoid division by zero
+return 0.0, None
+intersection = torch.logical_and(mask1, mask2)
+return intersection.sum().item() / union, intersection

for (label, masks), (other_label, other_masks) in product(predicted_masks.items(), predicted_masks.items()):
if other_label <= label:
@@ -415,11 +414,20 @@ def _calculate_mask_iou(mask1: Tensor, mask2: Tensor) -> Tensor:
overlapped_label = []
overlapped_other_label = []
for (im, mask), (jm, other_mask) in product(enumerate(masks), enumerate(other_masks)):
-if _calculate_mask_iou(mask, other_mask) > threshold_iou:
+_mask_iou, _intersection = _calculate_mask_iou(mask, other_mask)
+if _mask_iou > threshold_iou:
# compare overlapped regions between different labels and filter out the lower score
if used_points[label][im][2] > used_points[other_label][jm][2]:
overlapped_other_label.append(jm)
else:
overlapped_label.append(im)
+elif _mask_iou > 0:
+# refine the slightly overlapping region
+overlapped_coords = torch.where(_intersection)
+if used_points[label][im][2] > used_points[other_label][jm][2]:
+other_mask[overlapped_coords] = 0.0
+else:
+mask[overlapped_coords] = 0.0

for im in sorted(list(set(overlapped_label)), reverse=True): # noqa: C414
masks.pop(im)
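For context before the next file: the refinement branch added above distinguishes heavy overlaps (IoU above `threshold_iou`, where the lower-scoring mask is dropped entirely) from slight ones (0 < IoU <= threshold, where only the shared pixels are erased from the lower-scoring mask). A minimal, self-contained sketch of that behavior on toy tensors (the helper and mask names are illustrative, not part of the diff):

import torch

def calculate_mask_iou(mask1: torch.Tensor, mask2: torch.Tensor):
    # Mirrors the updated helper: returns the IoU and the intersection mask.
    if (union := torch.logical_or(mask1, mask2).sum().item()) == 0:
        return 0.0, None  # avoid division by zero
    intersection = torch.logical_and(mask1, mask2)
    return intersection.sum().item() / union, intersection

# Two 4x4 masks sharing a single pixel: IoU = 1/7, a "slight" overlap.
mask_a = torch.zeros(4, 4, dtype=torch.bool)
mask_b = torch.zeros(4, 4, dtype=torch.bool)
mask_a[0, :] = True   # four pixels in row 0
mask_b[0, 3] = True   # shares pixel (0, 3) with mask_a
mask_b[1, :3] = True  # three more pixels in row 1

iou, intersection = calculate_mask_iou(mask_a, mask_b)
assert 0 < iou <= 0.8  # slight overlap: refine rather than drop a mask

# Suppose mask_a's prompt scored higher: clear only the shared pixels of mask_b.
mask_b[torch.where(intersection)] = False
assert calculate_mask_iou(mask_a, mask_b)[0] == 0.0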
161 changes: 142 additions & 19 deletions src/otx/core/model/entity/visual_prompting.py
@@ -10,14 +10,14 @@
import pickle
from collections import defaultdict
from copy import deepcopy
+from functools import partial
from itertools import product
from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any, Literal

import cv2
import numpy as np
import torch
-from openvino.model_api.models import Model
from torchvision import tv_tensors

from otx.core.data.entity.base import OTXBatchLossEntity, Points, T_OTXBatchPredEntityWithXAI
@@ -32,6 +32,11 @@
from otx.core.exporter.visual_prompting import OTXVisualPromptingModelExporter
from otx.core.model.entity.base import OTXModel, OVModel

+if TYPE_CHECKING:
+from openvino.model_api.models import Model
+
+from otx.core.data.module import OTXDataModule


class OTXVisualPromptingModel(
OTXModel[
@@ -57,13 +62,37 @@ def _export_parameters(self) -> dict[str, Any]:
export_params = super()._export_parameters
export_params["metadata"].update(
{
("model_info", "model_type"): "segment_anything",
("model_info", "model_type"): "Visual_Prompting",
("model_info", "task_type"): "visual_prompting",
},
)
export_params["input_size"] = (1, 3, self.model.image_size, self.model.image_size)
export_params["resize_mode"] = "fit_to_window"
export_params["mean"] = (123.675, 116.28, 103.53)
export_params["std"] = (58.395, 57.12, 57.375)
return export_params

@property
def _optimization_config(self) -> dict[str, Any]:
"""PTQ config for visual prompting models."""
return {
"model_type": "transformer",
"advanced_parameters": {
"activations_range_estimator_params": {
"min": {
"statistics_type": "QUANTILE",
"aggregator_type": "MIN",
"quantile_outlier_prob": "1e-4",
},
"max": {
"statistics_type": "QUANTILE",
"aggregator_type": "MAX",
"quantile_outlier_prob": "1e-4",
},
},
},
}

def _reset_prediction_layer(self, num_classes: int) -> None:
return

@@ -98,8 +127,9 @@ def __init__(
async_inference = False

basename: str = Path(model_name).name
+model_type_name: str = "_".join(basename.split("_")[:2])
self.model_names: dict[str, str] = {
-module: model_name.replace(basename, f"exported_model_{module}.xml")
+module: model_name.replace(basename, f"{model_type_name}_{module}.xml")
for module in ["image_encoder", "decoder"]
}
super().__init__(
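With this renaming, both exported IR files share the model-type prefix instead of the generic `exported_model_*` names. A toy check of the scheme (the file name is hypothetical):

from pathlib import Path

model_name = "models/sam_tiny_vit.xml"               # hypothetical checkpoint path
basename = Path(model_name).name                     # "sam_tiny_vit.xml"
model_type_name = "_".join(basename.split("_")[:2])  # "sam_tiny"
model_names = {
    module: model_name.replace(basename, f"{model_type_name}_{module}.xml")
    for module in ["image_encoder", "decoder"]
}
assert model_names["decoder"] == "models/sam_tiny_decoder.xml"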
@@ -115,6 +145,7 @@ def __init__(
def _create_model(self) -> dict[str, Model]:
"""Create a OV model with help of Model API."""
from openvino.model_api.adapters import OpenvinoAdapter, create_core, get_user_config
+from openvino.model_api.models import Model

ov_models: dict[str, Model] = {}

@@ -225,6 +256,90 @@ def _customize_outputs(
labels=[torch.cat(list(labels.values())) for labels in inputs.labels],
)

def optimize( # type: ignore[override]
self,
output_dir: Path,
data_module: OTXDataModule,
ptq_config: dict[str, Any] | None = None,
) -> dict[str, Path]:
"""Runs NNCF quantization."""
import nncf
import openvino

def check_if_quantized(model: openvino.Model) -> bool:
"""Checks if OpenVINO model is already quantized."""
nodes = model.get_ops()
return any(op.get_type_name() == "FakeQuantize" for op in nodes)

def transform_fn(
data_batch: VisualPromptingBatchDataEntity | ZeroShotVisualPromptingBatchDataEntity,
module: Literal["image_encoder", "decoder"],
) -> np.ndarray | dict[str, Any]:
images, _, prompts = self._customize_inputs(data_batch) # type: ignore[arg-type]

image = images[0]["images"] # use only the first image
if module == "image_encoder":
# resize
resized_image = self.model["image_encoder"].resize(
image[0],
(self.model["image_encoder"].w, self.model["image_encoder"].h),
)

# pad image if necessary because `fit_to_window` resize for python in modelapi doesn't support pad
pad_w = max(0, self.model["image_encoder"].w - resized_image.shape[1])
pad_h = max(0, self.model["image_encoder"].h - resized_image.shape[0])
resized_image = np.pad(
resized_image,
((0, pad_h), (0, pad_w), (0, 0)),
mode="constant",
constant_values=0,
)
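# Illustrative numbers, not from the diff: with a 1024x1024 encoder input and
# a 683x1024 keep-aspect-ratio resize, pad_h = 341 zero rows are appended at
# the bottom and pad_w = 0.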

# normalization
resized_image = self.model["image_encoder"].input_transform(resized_image)

# change layout from HWC to NCHW
return self.model["image_encoder"]._change_layout(resized_image) # noqa: SLF001

# obtain image embeddings from image encoder
image_embeddings = self.model["image_encoder"].infer_sync(image)
# use only the first prompt
prompt_for_optim = next(iter(prompts[0].values()))[0] if isinstance(prompts[0], dict) else prompts[0][0] # type: ignore[attr-defined]
prompt_for_optim.pop("label")
prompt_for_optim.update(**image_embeddings)
return prompt_for_optim
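# In short: a decoder calibration sample is one preprocessed prompt dict
# (its "label" entry removed) merged with the image embeddings produced by
# the already-loaded image encoder, so the decoder sees realistic inputs.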

output_model_paths: dict[str, Path] = {}
for module in ["image_encoder", "decoder"]:
output_model_path = output_dir / (self._OPTIMIZED_MODEL_BASE_NAME + f"_{module}.xml")

ov_model = openvino.Core().read_model(self.model_names[module])
if check_if_quantized(ov_model):
msg = "Model is already optimized by PTQ"
raise RuntimeError(msg)

train_dataset = data_module.train_dataloader()

ptq_config_from_ir = self._read_ptq_config_from_ir(ov_model)
if ptq_config is not None:
ptq_config_from_ir.update(ptq_config)
ptq_config = ptq_config_from_ir
else:
ptq_config = ptq_config_from_ir

quantization_dataset = nncf.Dataset(train_dataset, partial(transform_fn, module=module)) # type: ignore[attr-defined]

compressed_model = nncf.quantize( # type: ignore[attr-defined]
ov_model,
quantization_dataset,
**ptq_config,
)

openvino.save_model(compressed_model, output_model_path)
output_model_paths[module] = output_model_path

return output_model_paths
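A hypothetical driver for the two-stage flow above; `model` and `datamodule` are placeholders, and the output names assume `_OPTIMIZED_MODEL_BASE_NAME` is "optimized_model", as the integration test below also expects:

from pathlib import Path

# model: an OVVisualPromptingModel built from exported IR files
# datamodule: an OTXDataModule providing the calibration dataloader
paths = model.optimize(Path("outputs"), datamodule, ptq_config=None)
# expected mapping, per the loop above:
# {"image_encoder": Path("outputs/optimized_model_image_encoder.xml"),
#  "decoder": Path("outputs/optimized_model_decoder.xml")}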


class OVZeroShotVisualPromptingModel(OVVisualPromptingModel):
"""Zero-shot visual prompting model compatible for OpenVINO IR inference.
Expand Down Expand Up @@ -427,7 +542,7 @@ def _customize_inputs( # type: ignore[override]
images: list[np.ndarray] = []
metas: list[dict[str, Any]] = []
processed_prompts: list[list[dict[str, Any]]] = []
-for image, prompts, label, imgs_info in zip(
+for image, prompts, labels, imgs_info in zip(
entity.images,
entity.prompts,
entity.labels,
Expand All @@ -442,22 +557,22 @@ def _customize_inputs( # type: ignore[override]
if self.training:
points: list[np.ndarray] = []
bboxes: list[np.ndarray] = []
-labels: dict[str, list[int]] = defaultdict(list)
-for prompt in prompts:
+_labels: dict[str, list[int]] = defaultdict(list)
+for prompt, label in zip(prompts, labels):
if isinstance(prompt, tv_tensors.BoundingBoxes):
bboxes.append(prompt.cpu().numpy())
-labels["bboxes"].append(label.cpu().numpy())
+_labels["bboxes"].append(label.cpu().numpy())
elif isinstance(prompt, Points):
points.append(prompt.cpu().numpy())
-labels["points"].append(label.cpu().numpy())
+_labels["points"].append(label.cpu().numpy())

# preprocess decoder inputs
processed_prompts.append(
self.model["decoder"].preprocess(
{
"bboxes": bboxes,
"points": points,
"labels": labels,
"labels": _labels,
"orig_size": imgs_info.ori_shape,
},
),
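The fix above pairs each prompt with its own label instead of reusing a single `label` variable from the outer loop. A toy version of the grouping logic, with plain tuples standing in for `BoundingBoxes` and `Points`:

from collections import defaultdict

prompts = [("bbox", (0, 0, 10, 10)), ("point", (5, 5))]  # toy stand-ins
labels = [1, 2]
_labels: dict[str, list[int]] = defaultdict(list)
for (kind, _), label in zip(prompts, labels):
    _labels["bboxes" if kind == "bbox" else "points"].append(label)
assert dict(_labels) == {"bboxes": [1], "points": [2]}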
@@ -820,16 +935,14 @@ def _inspect_overlapping_areas(
used_points: dict[int, list[np.ndarray]],
threshold_iou: float = 0.8,
) -> None:
-def _calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> np.ndarray:
+def _calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> tuple[float, np.ndarray | None]:
assert mask1.ndim == 2  # noqa: S101
assert mask2.ndim == 2  # noqa: S101
-intersection = np.logical_and(mask1, mask2).sum().item()
-union = np.logical_or(mask1, mask2).sum().item()
-# Avoid division by zero
-if union == 0:
-return 0.0
-return intersection / union
+# Avoid division by zero
+if (union := np.logical_or(mask1, mask2).sum().item()) == 0:
+return 0.0, None
+intersection = np.logical_and(mask1, mask2)
+return intersection.sum().item() / union, intersection

for (label, masks), (other_label, other_masks) in product(predicted_masks.items(), predicted_masks.items()):
if other_label <= label:
@@ -838,11 +951,19 @@ def _calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> np.ndarray:
overlapped_label = []
overlapped_other_label = []
for (im, mask), (jm, other_mask) in product(enumerate(masks), enumerate(other_masks)):
-if _calculate_mask_iou(mask, other_mask) > threshold_iou:
+_mask_iou, _intersection = _calculate_mask_iou(mask, other_mask)
+if _mask_iou > threshold_iou:
if used_points[label][im][2] > used_points[other_label][jm][2]:
overlapped_other_label.append(jm)
else:
overlapped_label.append(im)
+elif _mask_iou > 0:
+# refine the slightly overlapping region
+overlapped_coords = np.where(_intersection)
+if used_points[label][im][2] > used_points[other_label][jm][2]:
+other_mask[overlapped_coords] = 0.0
+else:
+mask[overlapped_coords] = 0.0

for im in sorted(set(overlapped_label), reverse=True):
masks.pop(im)
Expand All @@ -861,7 +982,9 @@ def _topk_numpy(self, x: np.ndarray, k: int, axis: int = -1, largest: bool = Tru
indices = range(k)
partitioned_ind = np.argpartition(x, k, axis=axis).take(indices=indices, axis=axis)
partitioned_scores = np.take_along_axis(x, partitioned_ind, axis=axis)
-sorted_trunc_ind = np.flip(np.argsort(partitioned_scores, axis=axis), axis=axis)
+sorted_trunc_ind = np.argsort(partitioned_scores, axis=axis)
+if largest:
+sorted_trunc_ind = np.flip(sorted_trunc_ind, axis=axis)
ind = np.take_along_axis(partitioned_ind, sorted_trunc_ind, axis=axis)
scores = np.take_along_axis(partitioned_scores, sorted_trunc_ind, axis=axis)
return scores, ind
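The flip is now conditional, so `largest=False` requests keep ascending order instead of always being reversed. A toy check of the ordering step:

import numpy as np

scores = np.array([0.2, 0.8, 0.5])
order = np.argsort(scores)   # ascending: [0, 2, 1]
largest = True
if largest:
    order = np.flip(order)   # descending: [1, 2, 0]
assert scores[order].tolist() == [0.8, 0.5, 0.2]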
14 changes: 9 additions & 5 deletions tests/integration/cli/test_export_inference.py
@@ -211,8 +211,9 @@ def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device:

# 5) test optimize
if task in ("visual_prompting", "zero_shot_visual_prompting"):
log.info(f"{task} will support optimize in the future. Skip the test.")
return
pytest.xfail(
"Optimize for visual prompting and zero shot visual prompting yields segmentation fault after optimize.",
)

command_cfg = [
"otx",
@@ -238,7 +239,10 @@ def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device:
key=lambda p: p.stat().st_mtime,
)
assert latest_dir.exists()
-exported_model_path = str(latest_dir / "optimized_model.xml")
+if task in ("visual_prompting", "zero_shot_visual_prompting"):
+exported_model_path = str(latest_dir / "optimized_model_decoder.xml")
+else:
+exported_model_path = str(latest_dir / "optimized_model.xml")

# 6) test optimized model
tmp_path_test = run_cli_test(export_test_recipe, exported_model_path, Path("outputs") / "nncf_ptq", "cpu")
@@ -276,8 +280,8 @@ def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device:
msg = f"Recipe: {recipe}, (torch_accuracy, ov_accuracy): {torch_acc} , {ov_acc}"
log.info(msg)

-# Not compare w/ instance segmentation because training isn't able to be deterministic, which can lead to unstable test result.
-if "maskrcnn_efficientnetb2b" in recipe:
+# Don't compare w/ instance segmentation and visual prompting tasks because training isn't deterministic, which can lead to unstable test results.
+if "maskrcnn_efficientnetb2b" in recipe or task in ("visual_prompting", "zero_shot_visual_prompting"):
return

if "multi_label_cls/mobilenet_v3_large_light" in request.node.name: