Add optimize for visual prompting to 2.0 #3040

Merged on Mar 11, 2024 (69 commits)
Changes from 67 commits
f111cf1
Fix return's shape when all predicted masks are zero
sungchul2 Feb 6, 2024
722918c
Swap order of prompts in segment_anything.py
sungchul2 Feb 20, 2024
f9e1254
Merge branch 'v2' into v2_add_vpm_export
sungchul2 Feb 20, 2024
50accde
Enable to export segment anything modules respectively
sungchul2 Feb 21, 2024
d3b9a9a
Update prompt order and refactoring
sungchul2 Feb 22, 2024
9b4b42e
Fix `return_single_mask` to True
sungchul2 Feb 22, 2024
54f31bd
Add OVIR inference logic
sungchul2 Feb 22, 2024
4bc4cfa
Updates for integration test
sungchul2 Feb 23, 2024
7671271
precommit
sungchul2 Feb 23, 2024
0155526
Update prompt getter
sungchul2 Feb 23, 2024
b1da2dc
(WIP) Update zero-shot
sungchul2 Feb 23, 2024
1da36e2
(WIP) Internalize preprocess into model
sungchul2 Feb 27, 2024
eed5e2c
Update reference infos to buffer
sungchul2 Feb 27, 2024
33f6dd1
Update reference_info path
sungchul2 Feb 27, 2024
2d48108
Enable standalone infer logic
sungchul2 Feb 27, 2024
af901e4
Update location of `-load_latest_reference_info` for OVModel
sungchul2 Feb 27, 2024
52a8068
Update recipes
sungchul2 Feb 28, 2024
6ac1ec5
Enable `infer` ov inference
sungchul2 Feb 28, 2024
20ba18e
Enable standalone `infer` logic on OVModel
sungchul2 Feb 28, 2024
9f4c3cd
precommit
sungchul2 Feb 28, 2024
a754c84
Update unittest
sungchul2 Feb 29, 2024
6d34e89
Update for integration test
sungchul2 Feb 29, 2024
81c1672
Merge branch 'v2' into v2_add_vpm_export
sungchul2 Feb 29, 2024
a41feb8
precommit
sungchul2 Feb 29, 2024
eccf7ae
Fix
sungchul2 Feb 29, 2024
2de2fd7
Fix
sungchul2 Feb 29, 2024
fc89558
Fix
sungchul2 Feb 29, 2024
6744dfe
Fix intg tests
sungchul2 Feb 29, 2024
2708d7b
Enable to update `export_args` to `deploy_cfg`
sungchul2 Feb 29, 2024
731447b
Update with walrus
sungchul2 Mar 4, 2024
03d9795
Update to use dict labels
sungchul2 Mar 4, 2024
c7f5f1e
Change openvino model names
sungchul2 Mar 4, 2024
b8d1b51
Update compatibility with zero-shot
sungchul2 Mar 4, 2024
f1aa67c
Refactoring for unnecessary assigned variables
sungchul2 Mar 4, 2024
180990a
Avoid repeatedly executing `torch.cat`
sungchul2 Mar 4, 2024
4220ded
precommit
sungchul2 Mar 4, 2024
a07f4c6
Fix unit test
sungchul2 Mar 4, 2024
27a05cd
Update variable name
sungchul2 Mar 4, 2024
979a01b
Add `example_inputs` in anomaly
sungchul2 Mar 4, 2024
14fa2c7
Fix unit test
sungchul2 Mar 4, 2024
25bb988
Fix
sungchul2 Mar 4, 2024
05bfdfa
Update `model_names` for visual prompting
sungchul2 Mar 5, 2024
05104a6
precommit
sungchul2 Mar 5, 2024
73f081a
Not to include other params in `example_inputs`
sungchul2 Mar 5, 2024
d9335de
Disable condition for visual prompting
sungchul2 Mar 5, 2024
97d674b
Update to `example_inputs`
sungchul2 Mar 5, 2024
0a9eaa1
Remove unused kwargs
sungchul2 Mar 5, 2024
9393dd4
Update
sungchul2 Mar 5, 2024
c86ddca
Remove unused parts
sungchul2 Mar 5, 2024
84ed9f4
Update exported models' names
sungchul2 Mar 5, 2024
2d75684
Remove `example_inputs`
sungchul2 Mar 5, 2024
a80aae2
Add `OTXVisualPromptingModelExporter`
sungchul2 Mar 5, 2024
6a405fe
Update overlapped region refinement
sungchul2 Feb 29, 2024
d4c9a50
Update `export_params`
sungchul2 Mar 6, 2024
1217721
Merge branch 'v2' into v2_add_vpm_optimize
sungchul2 Mar 6, 2024
af662f4
Enable optimize
sungchul2 Mar 6, 2024
e361b74
Add exportable code, but updating `demo.py` is required
sungchul2 Mar 7, 2024
6d515c4
Update model_type
sungchul2 Mar 7, 2024
033826b
Fix integration test
sungchul2 Mar 7, 2024
96f8147
Merge branch 'v2' into v2_add_vpm_optimize
sungchul2 Mar 7, 2024
38c8ba2
Add unit test
sungchul2 Mar 7, 2024
d248f81
Refactoring redundant parts
sungchul2 Mar 7, 2024
5ff1ff5
Revert exportable_code
sungchul2 Mar 7, 2024
9f812a7
Update unit test
sungchul2 Mar 7, 2024
7ceee75
Merge branch 'v2' into v2_add_vpm_optimize
sungchul2 Mar 7, 2024
9da32eb
Update unit test
sungchul2 Mar 7, 2024
3ffb11a
Fix unit test
sungchul2 Mar 7, 2024
c79ad8e
Temporarily fix integration test
sungchul2 Mar 8, 2024
e3d29ae
Revert to disable opening subprocess & add xfail for vpm tasks
sungchul2 Mar 8, 2024
22 changes: 15 additions & 7 deletions src/otx/algo/visual_prompting/zero_shot_segment_anything.py
@@ -400,13 +400,12 @@
used_points: dict[int, list[Tensor]],
threshold_iou: float = 0.8,
) -> None:
def _calculate_mask_iou(mask1: Tensor, mask2: Tensor) -> Tensor:
intersection = torch.logical_and(mask1, mask2).sum().item()
union = torch.logical_or(mask1, mask2).sum().item()
if union == 0:
def _calculate_mask_iou(mask1: Tensor, mask2: Tensor) -> tuple[float, Tensor | None]:
if (union := torch.logical_or(mask1, mask2).sum().item()) == 0:
# Avoid division by zero
return 0.0
return intersection / union
return 0.0, None
intersection = torch.logical_and(mask1, mask2)
return intersection.sum().item() / union, intersection

for (label, masks), (other_label, other_masks) in product(predicted_masks.items(), predicted_masks.items()):
if other_label <= label:
@@ -415,11 +414,20 @@
overlapped_label = []
overlapped_other_label = []
for (im, mask), (jm, other_mask) in product(enumerate(masks), enumerate(other_masks)):
if _calculate_mask_iou(mask, other_mask) > threshold_iou:
_mask_iou, _intersection = _calculate_mask_iou(mask, other_mask)
if _mask_iou > threshold_iou:
# compare overlapped regions between different labels and filter out the lower score
if used_points[label][im][2] > used_points[other_label][jm][2]:
overlapped_other_label.append(jm)
else:
overlapped_label.append(im)
elif _mask_iou > 0:
# refine the slightly overlapping region
overlapped_coords = torch.where(_intersection)
if used_points[label][im][2] > used_points[other_label][jm][2]:
other_mask[overlapped_coords] = 0.0
else:
mask[overlapped_coords] = 0.0

for im in sorted(list(set(overlapped_label)), reverse=True): # noqa: C414
masks.pop(im)
161 changes: 142 additions & 19 deletions src/otx/core/model/entity/visual_prompting.py
@@ -10,14 +10,14 @@
import pickle
from collections import defaultdict
from copy import deepcopy
from functools import partial
from itertools import product
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any, Literal

import cv2
import numpy as np
import torch
from openvino.model_api.models import Model
from torchvision import tv_tensors

from otx.core.data.entity.base import OTXBatchLossEntity, Points, T_OTXBatchPredEntityWithXAI
@@ -32,6 +32,11 @@
from otx.core.exporter.visual_prompting import OTXVisualPromptingModelExporter
from otx.core.model.entity.base import OTXModel, OVModel

if TYPE_CHECKING:
from openvino.model_api.models import Model

from otx.core.data.module import OTXDataModule


class OTXVisualPromptingModel(
OTXModel[
@@ -57,13 +62,37 @@
export_params = super()._export_parameters
export_params["metadata"].update(
{
("model_info", "model_type"): "segment_anything",
("model_info", "model_type"): "Visual_Prompting",
("model_info", "task_type"): "visual_prompting",
},
)
export_params["input_size"] = (1, 3, self.model.image_size, self.model.image_size)
export_params["resize_mode"] = "fit_to_window"
export_params["mean"] = (123.675, 116.28, 103.53)
export_params["std"] = (58.395, 57.12, 57.375)
return export_params

@property
def _optimization_config(self) -> dict[str, Any]:
"""PTQ config for visual prompting models."""
return {
"model_type": "transformer",
"advanced_parameters": {
"activations_range_estimator_params": {
"min": {
"statistics_type": "QUANTILE",
"aggregator_type": "MIN",
"quantile_outlier_prob": "1e-4",
},
"max": {
"statistics_type": "QUANTILE",
"aggregator_type": "MAX",
"quantile_outlier_prob": "1e-4",
},
},
},
}

def _reset_prediction_layer(self, num_classes: int) -> None:
return

@@ -98,8 +127,9 @@
async_inference = False

basename: str = Path(model_name).name
model_type_name: str = "_".join(basename.split("_")[:2])
self.model_names: dict[str, str] = {
module: model_name.replace(basename, f"exported_model_{module}.xml")
module: model_name.replace(basename, f"{model_type_name}_{module}.xml")
for module in ["image_encoder", "decoder"]
}
super().__init__(
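The naming logic in this hunk can be tried in isolation. The sketch below is a standalone illustration, not code from the PR: `derive_module_paths` is a hypothetical helper name, and the example path is made up. It shows how the first two underscore-separated tokens of the file name become the shared prefix for both module files.

```python
from __future__ import annotations

from pathlib import Path


def derive_module_paths(model_name: str) -> dict[str, str]:
    """Hypothetical helper mirroring the `model_names` logic in the diff."""
    basename = Path(model_name).name
    # "optimized_model_decoder.xml" -> "optimized_model"
    model_type_name = "_".join(basename.split("_")[:2])
    return {
        module: model_name.replace(basename, f"{model_type_name}_{module}.xml")
        for module in ["image_encoder", "decoder"]
    }


# Example path is an assumption for illustration only.
paths = derive_module_paths("outputs/optimized_model_decoder.xml")
```

Because the prefix is taken from the file name instead of being hard-coded to `exported_model`, the same code appears to resolve both exported and optimized IR files, provided their names follow the `<word>_<word>_<module>.xml` pattern.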
@@ -115,6 +145,7 @@
def _create_model(self) -> dict[str, Model]:
"""Create a OV model with help of Model API."""
from openvino.model_api.adapters import OpenvinoAdapter, create_core, get_user_config
from openvino.model_api.models import Model

ov_models: dict[str, Model] = {}

@@ -225,6 +256,90 @@
labels=[torch.cat(list(labels.values())) for labels in inputs.labels],
)

def optimize( # type: ignore[override]
self,
output_dir: Path,
data_module: OTXDataModule,
ptq_config: dict[str, Any] | None = None,
) -> dict[str, Path]:
"""Runs NNCF quantization."""
import nncf
import openvino

def check_if_quantized(model: openvino.Model) -> bool:
"""Checks if OpenVINO model is already quantized."""
nodes = model.get_ops()
return any(op.get_type_name() == "FakeQuantize" for op in nodes)

def transform_fn(
data_batch: VisualPromptingBatchDataEntity | ZeroShotVisualPromptingBatchDataEntity,
module: Literal["image_encoder", "decoder"],
) -> np.ndarray | dict[str, Any]:
images, _, prompts = self._customize_inputs(data_batch) # type: ignore[arg-type]

image = images[0]["images"] # use only the first image
if module == "image_encoder":
# resize
resized_image = self.model["image_encoder"].resize(
image[0],
(self.model["image_encoder"].w, self.model["image_encoder"].h),
)

# pad image if necessary because `fit_to_window` resize for python in modelapi doesn't support pad
pad_w = max(0, self.model["image_encoder"].w - resized_image.shape[1])
pad_h = max(0, self.model["image_encoder"].h - resized_image.shape[0])
resized_image = np.pad(
resized_image,
((0, pad_h), (0, pad_w), (0, 0)),
mode="constant",
constant_values=0,
)

# normalization
resized_image = self.model["image_encoder"].input_transform(resized_image)

# change layout from HWC to NCHW
return self.model["image_encoder"]._change_layout(resized_image) # noqa: SLF001

# obtain image embeddings from image encoder
image_embeddings = self.model["image_encoder"].infer_sync(image)
# use only the first prompt
prompt_for_optim = next(iter(prompts[0].values()))[0] if isinstance(prompts[0], dict) else prompts[0][0] # type: ignore[attr-defined]
prompt_for_optim.pop("label")
prompt_for_optim.update(**image_embeddings)
return prompt_for_optim

output_model_paths: dict[str, Path] = {}
for module in ["image_encoder", "decoder"]:
output_model_path = output_dir / (self._OPTIMIZED_MODEL_BASE_NAME + f"_{module}.xml")

ov_model = openvino.Core().read_model(self.model_names[module])
if check_if_quantized(ov_model):
msg = "Model is already optimized by PTQ"
raise RuntimeError(msg)

train_dataset = data_module.train_dataloader()

ptq_config_from_ir = self._read_ptq_config_from_ir(ov_model)
if ptq_config is not None:
ptq_config_from_ir.update(ptq_config)
ptq_config = ptq_config_from_ir
else:
ptq_config = ptq_config_from_ir

quantization_dataset = nncf.Dataset(train_dataset, partial(transform_fn, module=module)) # type: ignore[attr-defined]

compressed_model = nncf.quantize( # type: ignore[attr-defined]
ov_model,
quantization_dataset,
**ptq_config,
)

openvino.save_model(compressed_model, output_model_path)
output_model_paths[module] = output_model_path

return output_model_paths
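Two pieces of the `optimize` loop above can be sketched without NNCF or OpenVINO installed: merging a user-supplied PTQ config over the one read from the IR, and binding the `module` argument with `functools.partial` so each module gets its own one-argument transform. This is a minimal sketch under assumptions: `merge_ptq_config` is a hypothetical helper, the `transform_fn` here is a stand-in that only echoes its inputs, and the config values are illustrative.

```python
from __future__ import annotations

from functools import partial
from typing import Any


def merge_ptq_config(
    config_from_ir: dict[str, Any],
    user_config: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Hypothetical helper: user-supplied keys override those read from the IR."""
    merged = dict(config_from_ir)  # copy so the IR config is left untouched
    if user_config is not None:
        merged.update(user_config)
    return merged


def transform_fn(data_batch: Any, module: str) -> dict[str, Any]:
    """Stand-in for the real transform; it only records what it was called with."""
    return {"module": module, "batch": data_batch}


# One single-argument callable per module, as a dataset wrapper would expect.
transforms = {
    module: partial(transform_fn, module=module)
    for module in ("image_encoder", "decoder")
}

# Illustrative values only.
ir_config = {"model_type": "transformer", "subset_size": 300}
merged = merge_ptq_config(ir_config, {"subset_size": 100})
```

In the diff, `ptq_config` is rebound inside the loop, so the merged config from the first module is carried into the second iteration. The sketch returns a fresh dict on each call instead, which is a choice made for this example and not something the PR does.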


class OVZeroShotVisualPromptingModel(OVVisualPromptingModel):
"""Zero-shot visual prompting model compatible for OpenVINO IR inference.
@@ -427,7 +542,7 @@
images: list[np.ndarray] = []
metas: list[dict[str, Any]] = []
processed_prompts: list[list[dict[str, Any]]] = []
for image, prompts, label, imgs_info in zip(
for image, prompts, labels, imgs_info in zip(
entity.images,
entity.prompts,
entity.labels,
@@ -442,22 +557,22 @@
if self.training:
points: list[np.ndarray] = []
bboxes: list[np.ndarray] = []
labels: dict[str, list[int]] = defaultdict(list)
for prompt in prompts:
_labels: dict[str, list[int]] = defaultdict(list)
for prompt, label in zip(prompts, labels):
if isinstance(prompt, tv_tensors.BoundingBoxes):
bboxes.append(prompt.cpu().numpy())
labels["bboxes"].append(label.cpu().numpy())
_labels["bboxes"].append(label.cpu().numpy())
elif isinstance(prompt, Points):
points.append(prompt.cpu().numpy())
labels["points"].append(label.cpu().numpy())
_labels["points"].append(label.cpu().numpy())

# preprocess decoder inputs
processed_prompts.append(
self.model["decoder"].preprocess(
{
"bboxes": bboxes,
"points": points,
"labels": labels,
"labels": _labels,
"orig_size": imgs_info.ori_shape,
},
),
@@ -820,16 +935,14 @@
used_points: dict[int, list[np.ndarray]],
threshold_iou: float = 0.8,
) -> None:
def _calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> np.ndarray:
def _calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> tuple[float, np.ndarray | None]:
assert mask1.ndim == 2 # noqa: S101
assert mask2.ndim == 2 # noqa: S101
intersection = np.logical_and(mask1, mask2).sum().item()
union = np.logical_or(mask1, mask2).sum().item()

# Avoid division by zero
if union == 0:
return 0.0
return intersection / union
if (union := np.logical_or(mask1, mask2).sum().item()) == 0:
return 0.0, None
intersection = np.logical_and(mask1, mask2)
return intersection.sum().item() / union, intersection

for (label, masks), (other_label, other_masks) in product(predicted_masks.items(), predicted_masks.items()):
if other_label <= label:
@@ -838,11 +951,19 @@
overlapped_label = []
overlapped_other_label = []
for (im, mask), (jm, other_mask) in product(enumerate(masks), enumerate(other_masks)):
if _calculate_mask_iou(mask, other_mask) > threshold_iou:
_mask_iou, _intersection = _calculate_mask_iou(mask, other_mask)
if _mask_iou > threshold_iou:
if used_points[label][im][2] > used_points[other_label][jm][2]:
overlapped_other_label.append(jm)
else:
overlapped_label.append(im)
elif _mask_iou > 0:
# refine the slightly overlapping region
overlapped_coords = np.where(_intersection)
if used_points[label][im][2] > used_points[other_label][jm][2]:
other_mask[overlapped_coords] = 0.0
else:
mask[overlapped_coords] = 0.0

for im in sorted(set(overlapped_label), reverse=True):
masks.pop(im)
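The refinement added in this hunk has three outcomes per mask pair: drop one mask when the IoU exceeds the threshold, zero out the shared pixels in the lower-scoring mask when the overlap is small but non-zero, or do nothing when the masks are disjoint. The NumPy sketch below is an illustration of that branching on a single pair. `refine_pair` and its string return values are hypothetical, and the scores are passed directly instead of being read from `used_points`.

```python
from __future__ import annotations

import numpy as np


def mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> tuple[float, np.ndarray | None]:
    """Return (IoU, intersection mask); the intersection is None for empty masks."""
    if (union := np.logical_or(mask1, mask2).sum().item()) == 0:
        return 0.0, None  # avoid division by zero
    intersection = np.logical_and(mask1, mask2)
    return intersection.sum().item() / union, intersection


def refine_pair(
    mask: np.ndarray,
    other_mask: np.ndarray,
    score: float,
    other_score: float,
    threshold_iou: float = 0.8,
) -> str:
    """Hypothetical single-pair version of the loop body in the diff."""
    iou, intersection = mask_iou(mask, other_mask)
    if iou > threshold_iou:
        # Heavy overlap: the lower-scoring mask would be removed entirely.
        return "drop_other" if score > other_score else "drop_self"
    if iou > 0:
        # Slight overlap: clear the shared pixels in the lower-scoring mask (in place).
        coords = np.where(intersection)
        if score > other_score:
            other_mask[coords] = 0.0
        else:
            mask[coords] = 0.0
        return "refined"
    return "disjoint"


a = np.zeros((4, 4))
a[:2, :3] = 1.0  # 6 pixels
b = np.zeros((4, 4))
b[:2, 2:] = 1.0  # 4 pixels, 2 of them shared with `a`
result = refine_pair(a, b, score=0.9, other_score=0.5)
```

Here the IoU is 2 / 8 = 0.25, below the 0.8 threshold, so the two shared pixels are cleared in `b` and `a` is left unchanged.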
@@ -861,7 +982,9 @@
indices = range(k)
partitioned_ind = np.argpartition(x, k, axis=axis).take(indices=indices, axis=axis)
partitioned_scores = np.take_along_axis(x, partitioned_ind, axis=axis)
sorted_trunc_ind = np.flip(np.argsort(partitioned_scores, axis=axis), axis=axis)
sorted_trunc_ind = np.argsort(partitioned_scores, axis=axis)
if largest:
sorted_trunc_ind = np.flip(sorted_trunc_ind, axis=axis)
ind = np.take_along_axis(partitioned_ind, sorted_trunc_ind, axis=axis)
scores = np.take_along_axis(partitioned_scores, sorted_trunc_ind, axis=axis)
return scores, ind
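The change in this hunk makes the final flip conditional on `largest`, so ascending order is kept when the smallest values are requested. Only part of the function is visible in the diff, so the following is a self-contained sketch of a NumPy top-k with that behaviour, not the PR's implementation. In particular, negating the input to select the largest values and partitioning at `k - 1` are assumptions of this example.

```python
from __future__ import annotations

import numpy as np


def topk_numpy(x: np.ndarray, k: int, axis: int = -1, largest: bool = True) -> tuple[np.ndarray, np.ndarray]:
    """Sketch of a top-k: returns (scores, indices) sorted along `axis`."""
    # Partition so the first k positions hold the k largest (or smallest) entries.
    keys = -x if largest else x
    partitioned_ind = np.argpartition(keys, k - 1, axis=axis).take(indices=range(k), axis=axis)
    partitioned_scores = np.take_along_axis(x, partitioned_ind, axis=axis)

    # Sort only the k selected entries; flip to descending order when `largest`.
    sorted_trunc_ind = np.argsort(partitioned_scores, axis=axis)
    if largest:
        sorted_trunc_ind = np.flip(sorted_trunc_ind, axis=axis)

    ind = np.take_along_axis(partitioned_ind, sorted_trunc_ind, axis=axis)
    scores = np.take_along_axis(partitioned_scores, sorted_trunc_ind, axis=axis)
    return scores, ind


x = np.array([3.0, 1.0, 4.0, 1.5, 9.0])
top_scores, top_ind = topk_numpy(x, k=2)
low_scores, low_ind = topk_numpy(x, k=2, largest=False)
```

With `largest=True` the result is in descending order (9.0, 4.0), and with `largest=False` it is ascending (1.0, 1.5), which is the ordering the conditional flip is meant to preserve.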
13 changes: 6 additions & 7 deletions tests/integration/cli/test_export_inference.py
@@ -210,10 +210,6 @@ def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device:
assert (tmp_path_test / "outputs").exists()

# 5) test optimize
if task in ("visual_prompting", "zero_shot_visual_prompting"):
log.info(f"{task} will support optimize in the future. Skip the test.")
return

command_cfg = [
"otx",
"optimize",
@@ -238,7 +234,10 @@ def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device:
key=lambda p: p.stat().st_mtime,
)
assert latest_dir.exists()
exported_model_path = str(latest_dir / "optimized_model.xml")
if task in ("visual_prompting", "zero_shot_visual_prompting"):
exported_model_path = str(latest_dir / "optimized_model_decoder.xml")
else:
exported_model_path = str(latest_dir / "optimized_model.xml")

# 6) test optimized model
tmp_path_test = run_cli_test(export_test_recipe, exported_model_path, Path("outputs") / "nncf_ptq", "cpu")
@@ -276,8 +275,8 @@ def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device:
msg = f"Recipe: {recipe}, (torch_accuracy, ov_accuracy): {torch_acc} , {ov_acc}"
log.info(msg)

# Not compare w/ instance segmentation because training isn't able to be deterministic, which can lead to unstable test result.
if "maskrcnn_efficientnetb2b" in recipe:
# Not compare w/ instance segmentation and visual prompting tasks because training isn't able to be deterministic, which can lead to unstable test result.
if "maskrcnn_efficientnetb2b" in recipe or task in ("visual_prompting", "zero_shot_visual_prompting"):
return

if "multi_label_cls/mobilenet_v3_large_light" in request.node.name: