diff --git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py index 72be3e36f62..8e64aea639b 100644 --- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py +++ b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py @@ -400,13 +400,12 @@ def _inspect_overlapping_areas( used_points: dict[int, list[Tensor]], threshold_iou: float = 0.8, ) -> None: - def _calculate_mask_iou(mask1: Tensor, mask2: Tensor) -> Tensor: - intersection = torch.logical_and(mask1, mask2).sum().item() - union = torch.logical_or(mask1, mask2).sum().item() - if union == 0: + def _calculate_mask_iou(mask1: Tensor, mask2: Tensor) -> tuple[float, Tensor | None]: + if (union := torch.logical_or(mask1, mask2).sum().item()) == 0: # Avoid division by zero - return 0.0 - return intersection / union + return 0.0, None + intersection = torch.logical_and(mask1, mask2) + return intersection.sum().item() / union, intersection for (label, masks), (other_label, other_masks) in product(predicted_masks.items(), predicted_masks.items()): if other_label <= label: @@ -415,11 +414,20 @@ def _calculate_mask_iou(mask1: Tensor, mask2: Tensor) -> Tensor: overlapped_label = [] overlapped_other_label = [] for (im, mask), (jm, other_mask) in product(enumerate(masks), enumerate(other_masks)): - if _calculate_mask_iou(mask, other_mask) > threshold_iou: + _mask_iou, _intersection = _calculate_mask_iou(mask, other_mask) + if _mask_iou > threshold_iou: + # compare overlapped regions between different labels and filter out the lower score if used_points[label][im][2] > used_points[other_label][jm][2]: overlapped_other_label.append(jm) else: overlapped_label.append(im) + elif _mask_iou > 0: + # refine the slightly overlapping region + overlapped_coords = torch.where(_intersection) + if used_points[label][im][2] > used_points[other_label][jm][2]: + other_mask[overlapped_coords] = 0.0 + else: + mask[overlapped_coords] = 0.0 for im in sorted(list(set(overlapped_label)), reverse=True): # noqa: C414 masks.pop(im) diff --git a/src/otx/core/model/entity/visual_prompting.py b/src/otx/core/model/entity/visual_prompting.py index 193edef961a..9c59cab39cc 100644 --- a/src/otx/core/model/entity/visual_prompting.py +++ b/src/otx/core/model/entity/visual_prompting.py @@ -10,14 +10,14 @@ import pickle from collections import defaultdict from copy import deepcopy +from functools import partial from itertools import product from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any, Literal import cv2 import numpy as np import torch -from openvino.model_api.models import Model from torchvision import tv_tensors from otx.core.data.entity.base import OTXBatchLossEntity, Points, T_OTXBatchPredEntityWithXAI @@ -32,6 +32,11 @@ from otx.core.exporter.visual_prompting import OTXVisualPromptingModelExporter from otx.core.model.entity.base import OTXModel, OVModel +if TYPE_CHECKING: + from openvino.model_api.models import Model + + from otx.core.data.module import OTXDataModule + class OTXVisualPromptingModel( OTXModel[ @@ -57,13 +62,37 @@ def _export_parameters(self) -> dict[str, Any]: export_params = super()._export_parameters export_params["metadata"].update( { - ("model_info", "model_type"): "segment_anything", + ("model_info", "model_type"): "Visual_Prompting", ("model_info", "task_type"): "visual_prompting", }, ) export_params["input_size"] = (1, 3, self.model.image_size, self.model.image_size) + export_params["resize_mode"] = "fit_to_window" + 
export_params["mean"] = (123.675, 116.28, 103.53) + export_params["std"] = (58.395, 57.12, 57.375) return export_params + @property + def _optimization_config(self) -> dict[str, Any]: + """PTQ config for visual prompting models.""" + return { + "model_type": "transformer", + "advanced_parameters": { + "activations_range_estimator_params": { + "min": { + "statistics_type": "QUANTILE", + "aggregator_type": "MIN", + "quantile_outlier_prob": "1e-4", + }, + "max": { + "statistics_type": "QUANTILE", + "aggregator_type": "MAX", + "quantile_outlier_prob": "1e-4", + }, + }, + }, + } + def _reset_prediction_layer(self, num_classes: int) -> None: return @@ -98,8 +127,9 @@ def __init__( async_inference = False basename: str = Path(model_name).name + model_type_name: str = "_".join(basename.split("_")[:2]) self.model_names: dict[str, str] = { - module: model_name.replace(basename, f"exported_model_{module}.xml") + module: model_name.replace(basename, f"{model_type_name}_{module}.xml") for module in ["image_encoder", "decoder"] } super().__init__( @@ -115,6 +145,7 @@ def __init__( def _create_model(self) -> dict[str, Model]: """Create a OV model with help of Model API.""" from openvino.model_api.adapters import OpenvinoAdapter, create_core, get_user_config + from openvino.model_api.models import Model ov_models: dict[str, Model] = {} @@ -225,6 +256,90 @@ def _customize_outputs( labels=[torch.cat(list(labels.values())) for labels in inputs.labels], ) + def optimize( # type: ignore[override] + self, + output_dir: Path, + data_module: OTXDataModule, + ptq_config: dict[str, Any] | None = None, + ) -> dict[str, Path]: + """Runs NNCF quantization.""" + import nncf + import openvino + + def check_if_quantized(model: openvino.Model) -> bool: + """Checks if OpenVINO model is already quantized.""" + nodes = model.get_ops() + return any(op.get_type_name() == "FakeQuantize" for op in nodes) + + def transform_fn( + data_batch: VisualPromptingBatchDataEntity | ZeroShotVisualPromptingBatchDataEntity, + module: Literal["image_encoder", "decoder"], + ) -> np.ndarray | dict[str, Any]: + images, _, prompts = self._customize_inputs(data_batch) # type: ignore[arg-type] + + image = images[0]["images"] # use only the first image + if module == "image_encoder": + # resize + resized_image = self.model["image_encoder"].resize( + image[0], + (self.model["image_encoder"].w, self.model["image_encoder"].h), + ) + + # pad image if necessary because `fit_to_window` resize for python in modelapi doesn't support pad + pad_w = max(0, self.model["image_encoder"].w - resized_image.shape[1]) + pad_h = max(0, self.model["image_encoder"].h - resized_image.shape[0]) + resized_image = np.pad( + resized_image, + ((0, pad_h), (0, pad_w), (0, 0)), + mode="constant", + constant_values=0, + ) + + # normalization + resized_image = self.model["image_encoder"].input_transform(resized_image) + + # change layout from HWC to NCHW + return self.model["image_encoder"]._change_layout(resized_image) # noqa: SLF001 + + # obtain image embeddings from image encoder + image_embeddings = self.model["image_encoder"].infer_sync(image) + # use only the first prompt + prompt_for_optim = next(iter(prompts[0].values()))[0] if isinstance(prompts[0], dict) else prompts[0][0] # type: ignore[attr-defined] + prompt_for_optim.pop("label") + prompt_for_optim.update(**image_embeddings) + return prompt_for_optim + + output_model_paths: dict[str, Path] = {} + for module in ["image_encoder", "decoder"]: + output_model_path = output_dir / (self._OPTIMIZED_MODEL_BASE_NAME + 
f"_{module}.xml") + + ov_model = openvino.Core().read_model(self.model_names[module]) + if check_if_quantized(ov_model): + msg = "Model is already optimized by PTQ" + raise RuntimeError(msg) + + train_dataset = data_module.train_dataloader() + + ptq_config_from_ir = self._read_ptq_config_from_ir(ov_model) + if ptq_config is not None: + ptq_config_from_ir.update(ptq_config) + ptq_config = ptq_config_from_ir + else: + ptq_config = ptq_config_from_ir + + quantization_dataset = nncf.Dataset(train_dataset, partial(transform_fn, module=module)) # type: ignore[attr-defined] + + compressed_model = nncf.quantize( # type: ignore[attr-defined] + ov_model, + quantization_dataset, + **ptq_config, + ) + + openvino.save_model(compressed_model, output_model_path) + output_model_paths[module] = output_model_path + + return output_model_paths + class OVZeroShotVisualPromptingModel(OVVisualPromptingModel): """Zero-shot visual prompting model compatible for OpenVINO IR inference. @@ -427,7 +542,7 @@ def _customize_inputs( # type: ignore[override] images: list[np.ndarray] = [] metas: list[dict[str, Any]] = [] processed_prompts: list[list[dict[str, Any]]] = [] - for image, prompts, label, imgs_info in zip( + for image, prompts, labels, imgs_info in zip( entity.images, entity.prompts, entity.labels, @@ -442,14 +557,14 @@ def _customize_inputs( # type: ignore[override] if self.training: points: list[np.ndarray] = [] bboxes: list[np.ndarray] = [] - labels: dict[str, list[int]] = defaultdict(list) - for prompt in prompts: + _labels: dict[str, list[int]] = defaultdict(list) + for prompt, label in zip(prompts, labels): if isinstance(prompt, tv_tensors.BoundingBoxes): bboxes.append(prompt.cpu().numpy()) - labels["bboxes"].append(label.cpu().numpy()) + _labels["bboxes"].append(label.cpu().numpy()) elif isinstance(prompt, Points): points.append(prompt.cpu().numpy()) - labels["points"].append(label.cpu().numpy()) + _labels["points"].append(label.cpu().numpy()) # preprocess decoder inputs processed_prompts.append( @@ -457,7 +572,7 @@ def _customize_inputs( # type: ignore[override] { "bboxes": bboxes, "points": points, - "labels": labels, + "labels": _labels, "orig_size": imgs_info.ori_shape, }, ), @@ -820,16 +935,14 @@ def _inspect_overlapping_areas( used_points: dict[int, list[np.ndarray]], threshold_iou: float = 0.8, ) -> None: - def _calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> np.ndarray: + def _calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> tuple[float, np.ndarray | None]: assert mask1.ndim == 2 # noqa: S101 assert mask2.ndim == 2 # noqa: S101 - intersection = np.logical_and(mask1, mask2).sum().item() - union = np.logical_or(mask1, mask2).sum().item() - # Avoid division by zero - if union == 0: - return 0.0 - return intersection / union + if (union := np.logical_or(mask1, mask2).sum().item()) == 0: + return 0.0, None + intersection = np.logical_and(mask1, mask2) + return intersection.sum().item() / union, intersection for (label, masks), (other_label, other_masks) in product(predicted_masks.items(), predicted_masks.items()): if other_label <= label: @@ -838,11 +951,19 @@ def _calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> np.ndarray: overlapped_label = [] overlapped_other_label = [] for (im, mask), (jm, other_mask) in product(enumerate(masks), enumerate(other_masks)): - if _calculate_mask_iou(mask, other_mask) > threshold_iou: + _mask_iou, _intersection = _calculate_mask_iou(mask, other_mask) + if _mask_iou > threshold_iou: if used_points[label][im][2] > 
used_points[other_label][jm][2]: overlapped_other_label.append(jm) else: overlapped_label.append(im) + elif _mask_iou > 0: + # refine the slightly overlapping region + overlapped_coords = np.where(_intersection) + if used_points[label][im][2] > used_points[other_label][jm][2]: + other_mask[overlapped_coords] = 0.0 + else: + mask[overlapped_coords] = 0.0 for im in sorted(set(overlapped_label), reverse=True): masks.pop(im) @@ -861,7 +982,9 @@ def _topk_numpy(self, x: np.ndarray, k: int, axis: int = -1, largest: bool = Tru indices = range(k) partitioned_ind = np.argpartition(x, k, axis=axis).take(indices=indices, axis=axis) partitioned_scores = np.take_along_axis(x, partitioned_ind, axis=axis) - sorted_trunc_ind = np.flip(np.argsort(partitioned_scores, axis=axis), axis=axis) + sorted_trunc_ind = np.argsort(partitioned_scores, axis=axis) + if largest: + sorted_trunc_ind = np.flip(sorted_trunc_ind, axis=axis) ind = np.take_along_axis(partitioned_ind, sorted_trunc_ind, axis=axis) scores = np.take_along_axis(partitioned_scores, sorted_trunc_ind, axis=axis) return scores, ind diff --git a/tests/integration/cli/test_export_inference.py b/tests/integration/cli/test_export_inference.py index 4cb34f10301..f0bec76a94d 100644 --- a/tests/integration/cli/test_export_inference.py +++ b/tests/integration/cli/test_export_inference.py @@ -211,8 +211,9 @@ def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device: # 5) test optimize if task in ("visual_prompting", "zero_shot_visual_prompting"): - log.info(f"{task} will support optimize in the future. Skip the test.") - return + pytest.xfail( + "Optimize for visual prompting and zero shot visual prompting yields segmentation fault after optimize.", + ) command_cfg = [ "otx", @@ -238,7 +239,10 @@ def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device: key=lambda p: p.stat().st_mtime, ) assert latest_dir.exists() - exported_model_path = str(latest_dir / "optimized_model.xml") + if task in ("visual_prompting", "zero_shot_visual_prompting"): + exported_model_path = str(latest_dir / "optimized_model_decoder.xml") + else: + exported_model_path = str(latest_dir / "optimized_model.xml") # 6) test optimized model tmp_path_test = run_cli_test(export_test_recipe, exported_model_path, Path("outputs") / "nncf_ptq", "cpu") @@ -276,8 +280,8 @@ def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device: msg = f"Recipe: {recipe}, (torch_accuracy, ov_accuracy): {torch_acc} , {ov_acc}" log.info(msg) - # Not compare w/ instance segmentation because training isn't able to be deterministic, which can lead to unstable test result. - if "maskrcnn_efficientnetb2b" in recipe: + # Not compare w/ instance segmentation and visual prompting tasks because training isn't able to be deterministic, which can lead to unstable test result. 
+ if "maskrcnn_efficientnetb2b" in recipe or task in ("visual_prompting", "zero_shot_visual_prompting"): return if "multi_label_cls/mobilenet_v3_large_light" in request.node.name: diff --git a/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py b/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py index bab6fe692f9..95578b3de69 100644 --- a/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py +++ b/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py @@ -65,7 +65,7 @@ def test_get_prompt_candidates(self, mocker, prompt_getter, result_point_selecti mocker.patch.object(prompt_getter, "_point_selection", return_value=(result_point_selection, torch.zeros(1, 2))) image_embeddings = torch.ones(1, 4, 4, 4) reference_feats = torch.rand(1, 1, 4) - used_indices = torch.as_tensor([[0]]) + used_indices = torch.as_tensor([0]) ori_shape = torch.tensor([prompt_getter.image_size, prompt_getter.image_size], dtype=torch.int64) total_points_scores, total_bg_coords = prompt_getter.get_prompt_candidates( @@ -415,6 +415,25 @@ def test_inspect_overlapping_areas(self, mocker, build_zero_shot_segment_anythin assert all(torch.tensor([2, 2, 0.5]) == used_points[0][0]) assert all(torch.tensor([0, 0, 0.7]) == used_points[1][2]) + def test_predict_masks(self, mocker, build_zero_shot_segment_anything) -> None: + """Test _predict_masks.""" + mocker.patch( + "otx.algo.visual_prompting.segment_anything.SegmentAnything.forward", + return_value=(torch.ones(1, 4, 8, 8), torch.tensor([[0.1, 0.2, 0.5, 0.7]]), torch.ones(1, 4, 4, 4)), + ) + + zero_shot_segment_anything = build_zero_shot_segment_anything() + zero_shot_segment_anything.image_size = 6 + + mask = zero_shot_segment_anything._predict_masks( + mode="infer", + image_embeddings=torch.rand(1), + point_coords=torch.rand(1, 2, 2), + point_labels=torch.randint(low=0, high=2, size=(1, 2)), + ori_shape=torch.tensor([8, 8], dtype=torch.int64), + ) + assert mask.shape == (8, 8) + @pytest.mark.parametrize( ("masks", "logits", "expected"), [ @@ -451,6 +470,50 @@ def test_create_model(self, model) -> None: assert isinstance(zero_shot_segment_anything, torch.nn.Module) assert zero_shot_segment_anything.__class__.__name__ == "ZeroShotSegmentAnything" + @pytest.mark.parametrize("training", [True, False]) + def test_forward(self, mocker, model, training: bool) -> None: + """Test forward.""" + mocker_learn = mocker.patch.object(model, "learn") + mocker_infer = mocker.patch.object(model, "infer") + model.training = training + + model.forward(None) + + if training: + mocker_learn.assert_called_once() + else: + mocker_infer.assert_called_once() + + @pytest.mark.parametrize("reset_feat", [True, False]) + def test_learn(self, mocker, model, reset_feat: bool) -> None: + """Test learn.""" + mocker_initialize_reference_info = mocker.patch.object(model, "initialize_reference_info") + mocker_learn = mocker.patch.object(model.model, "learn") + mocker_customize_inputs = mocker.patch.object(model, "_customize_inputs") + mocker_customize_outputs = mocker.patch.object(model, "_customize_outputs") + + model.learn(None, reset_feat=reset_feat) + + if reset_feat: + mocker_initialize_reference_info.assert_called_once() + else: + mocker_initialize_reference_info.assert_not_called() + mocker_learn.assert_called_once() + mocker_customize_inputs.assert_called_once() + mocker_customize_outputs.assert_called_once() + + def test_infer(self, mocker, model) -> None: + """Test infer.""" + mocker_infer = mocker.patch.object(model.model, "infer") + 
mocker_customize_inputs = mocker.patch.object(model, "_customize_inputs") + mocker_customize_outputs = mocker.patch.object(model, "_customize_outputs") + + model.infer(None) + + mocker_infer.assert_called_once() + mocker_customize_inputs.assert_called_once() + mocker_customize_outputs.assert_called_once() + @pytest.mark.parametrize("is_training", [True, False]) def test_customize_inputs_learn( self, diff --git a/tests/unit/core/conftest.py b/tests/unit/core/conftest.py index 2ab8eff2a6c..bd9cfa6117b 100644 --- a/tests/unit/core/conftest.py +++ b/tests/unit/core/conftest.py @@ -1,10 +1,22 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations import pytest +import torch from otx.core.config import register_configs from otx.core.data.dataset.base import LabelInfo +from otx.core.data.entity.base import ImageInfo, Points from otx.core.data.entity.classification import HLabelData +from otx.core.data.entity.visual_prompting import ( + VisualPromptingBatchDataEntity, + VisualPromptingBatchPredEntity, + VisualPromptingDataEntity, + ZeroShotVisualPromptingBatchDataEntity, + ZeroShotVisualPromptingBatchPredEntity, + ZeroShotVisualPromptingDataEntity, +) +from torchvision import tv_tensors @pytest.fixture(scope="session", autouse=True) @@ -77,3 +89,108 @@ def fxt_hlabel_multilabel_info() -> HLabelData: ["5", "2"], ], ) + + +@pytest.fixture(scope="session") +def fxt_vpm_data_entity() -> ( + tuple[VisualPromptingDataEntity, VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity] +): + img_size = (1024, 1024) + fake_image = tv_tensors.Image(torch.rand(img_size)) + fake_image_info = ImageInfo(img_idx=0, img_shape=img_size, ori_shape=img_size) + fake_bboxes = tv_tensors.BoundingBoxes( + [[0, 0, 1, 1]], + format=tv_tensors.BoundingBoxFormat.XYXY, + canvas_size=img_size, + dtype=torch.float32, + ) + fake_points = Points([[2, 2]], canvas_size=img_size, dtype=torch.float32) + fake_masks = tv_tensors.Mask(torch.rand(img_size)) + fake_labels = {"bboxes": torch.as_tensor([1], dtype=torch.int64)} + fake_polygons = [None] + # define data entity + single_data_entity = VisualPromptingDataEntity( + image=fake_image, + img_info=fake_image_info, + masks=fake_masks, + labels=fake_labels, + polygons=fake_polygons, + bboxes=fake_bboxes, + points=fake_points, + ) + batch_data_entity = VisualPromptingBatchDataEntity( + batch_size=1, + images=[fake_image], + imgs_info=[fake_image_info], + masks=[fake_masks], + labels=[fake_labels], + polygons=[fake_polygons], + bboxes=[fake_bboxes], + points=[fake_points], + ) + batch_pred_data_entity = VisualPromptingBatchPredEntity( + batch_size=1, + images=[fake_image], + imgs_info=[fake_image_info], + masks=[fake_masks], + labels=[fake_labels], + polygons=[fake_polygons], + bboxes=[fake_bboxes], + points=[fake_points], + scores=[], + ) + + return single_data_entity, batch_data_entity, batch_pred_data_entity + + +@pytest.fixture(scope="session") +def fxt_zero_shot_vpm_data_entity() -> ( + tuple[ + ZeroShotVisualPromptingDataEntity, + ZeroShotVisualPromptingBatchDataEntity, + ZeroShotVisualPromptingBatchPredEntity, + ] +): + img_size = (1024, 1024) + fake_image = tv_tensors.Image(torch.rand(img_size)) + fake_image_info = ImageInfo(img_idx=0, img_shape=img_size, ori_shape=img_size) + fake_bboxes = tv_tensors.BoundingBoxes( + [[0, 0, 1, 1]], + format=tv_tensors.BoundingBoxFormat.XYXY, + canvas_size=img_size, + dtype=torch.float32, + ) + fake_points = Points([[2, 2]], canvas_size=img_size, dtype=torch.float32) + fake_masks = 
tv_tensors.Mask(torch.rand(img_size)) + fake_labels = torch.as_tensor([1], dtype=torch.int64) + fake_polygons = [None] + # define data entity + single_data_entity = ZeroShotVisualPromptingDataEntity( + image=fake_image, + img_info=fake_image_info, + masks=fake_masks, + labels=fake_labels, + polygons=fake_polygons, + prompts=[fake_bboxes, fake_points], + ) + batch_data_entity = ZeroShotVisualPromptingBatchDataEntity( + batch_size=1, + images=[fake_image], + imgs_info=[fake_image_info], + masks=[fake_masks], + labels=[fake_labels], + polygons=[fake_polygons], + prompts=[[fake_bboxes, fake_points]], + ) + batch_pred_data_entity = ZeroShotVisualPromptingBatchPredEntity( + batch_size=1, + images=[fake_image], + imgs_info=[fake_image_info], + masks=[fake_masks], + labels=[fake_labels], + polygons=[fake_polygons], + prompts=[[fake_bboxes, fake_points]], + scores=[], + ) + + return single_data_entity, batch_data_entity, batch_pred_data_entity diff --git a/tests/unit/core/model/entity/test_visual_prompting.py b/tests/unit/core/model/entity/test_visual_prompting.py new file mode 100644 index 00000000000..7373b31fb11 --- /dev/null +++ b/tests/unit/core/model/entity/test_visual_prompting.py @@ -0,0 +1,620 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Unit tests for visual prompting model entity.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import Mock + +import numpy as np +import pytest +import torch +from otx.core.data.entity.visual_prompting import VisualPromptingBatchPredEntity +from otx.core.exporter.visual_prompting import OTXVisualPromptingModelExporter +from otx.core.model.entity.visual_prompting import ( + OTXVisualPromptingModel, + OVVisualPromptingModel, + OVZeroShotVisualPromptingModel, +) +from torchvision import tv_tensors + + +class TestOTXVisualPromptingModel: + @pytest.fixture() + def otx_visual_prompting_model(self, mocker) -> OTXVisualPromptingModel: + mocker.patch.object(OTXVisualPromptingModel, "_create_model") + return OTXVisualPromptingModel(num_classes=1) + + def test_exporter(self, otx_visual_prompting_model) -> None: + """Test _exporter.""" + assert isinstance(otx_visual_prompting_model._exporter, OTXVisualPromptingModelExporter) + + def test_export_parameters(self, otx_visual_prompting_model) -> None: + """Test _export_parameters.""" + otx_visual_prompting_model.model.image_size = 1024 + + export_parameters = otx_visual_prompting_model._export_parameters + + assert export_parameters["input_size"] == (1, 3, 1024, 1024) + assert export_parameters["resize_mode"] == "fit_to_window" + assert export_parameters["mean"] == (123.675, 116.28, 103.53) + assert export_parameters["std"] == (58.395, 57.12, 57.375) + + def test_optimization_config(self, otx_visual_prompting_model) -> None: + """Test _optimization_config.""" + optimization_config = otx_visual_prompting_model._optimization_config + + assert optimization_config == { + "model_type": "transformer", + "advanced_parameters": { + "activations_range_estimator_params": { + "min": { + "statistics_type": "QUANTILE", + "aggregator_type": "MIN", + "quantile_outlier_prob": "1e-4", + }, + "max": { + "statistics_type": "QUANTILE", + "aggregator_type": "MAX", + "quantile_outlier_prob": "1e-4", + }, + }, + }, + } + + +class TestOVVisualPromptingModel: + @pytest.fixture() + def set_ov_visual_prompting_model(self, mocker): + def ov_visual_prompting_model(for_create_model: bool = False) -> OVVisualPromptingModel: + if for_create_model: + 
mocker.patch("openvino.model_api.adapters.create_core") + mocker.patch("openvino.model_api.adapters.get_user_config") + mocker.patch("openvino.model_api.adapters.OpenvinoAdapter") + mocker.patch("openvino.model_api.models.Model.create_model") + else: + mocker.patch.object( + OVVisualPromptingModel, + "_create_model", + return_value={"image_encoder": Mock(), "decoder": Mock()}, + ) + return OVVisualPromptingModel(num_classes=0, model_name="exported_model_decoder.xml") + + return ov_visual_prompting_model + + def test_create_model(self, set_ov_visual_prompting_model) -> None: + """Test _create_model.""" + ov_visual_prompting_model = set_ov_visual_prompting_model(for_create_model=True) + ov_models = ov_visual_prompting_model._create_model() + + assert isinstance(ov_models, dict) + assert "image_encoder" in ov_models + assert "decoder" in ov_models + + def test_forward(self, mocker, set_ov_visual_prompting_model, fxt_vpm_data_entity) -> None: + """Test forward.""" + ov_visual_prompting_model = set_ov_visual_prompting_model() + mocker.patch.object( + ov_visual_prompting_model.model["image_encoder"], + "preprocess", + return_value=(np.zeros((1, 3, 1024, 1024)), {}), + ) + mocker.patch.object( + ov_visual_prompting_model.model["image_encoder"], + "infer_sync", + return_value={"image_embeddings": np.random.random((1, 256, 64, 64))}, + ) + mocker.patch.object( + ov_visual_prompting_model.model["decoder"], + "preprocess", + return_value=[ + { + "point_coords": np.array([1, 1]).reshape(-1, 1, 2), + "point_labels": np.array([1], dtype=np.float32).reshape(-1, 1), + "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32), + "has_mask_input": np.zeros((1, 1), dtype=np.float32), + "orig_size": np.array([1024, 1024], dtype=np.int64).reshape(-1, 2), + "label": 1, + }, + ], + ) + mocker.patch.object( + ov_visual_prompting_model.model["decoder"], + "infer_sync", + return_value={ + "iou_predictions": 0.0, + "upscaled_masks": np.zeros((1, 1, 1024, 1024), dtype=np.float32), + }, + ) + mocker.patch.object( + ov_visual_prompting_model.model["decoder"], + "postprocess", + return_value={ + "hard_prediction": np.zeros((1, 1, 1024, 1024), dtype=np.float32), + "soft_prediction": np.zeros((1, 1, 1024, 1024), dtype=np.float32), + "scores": np.zeros((1, 1), dtype=np.float32), + }, + ) + + results = ov_visual_prompting_model(fxt_vpm_data_entity[1]) + + assert isinstance(results, VisualPromptingBatchPredEntity) + assert isinstance(results.images, list) + assert isinstance(results.images[0], tv_tensors.Image) + assert isinstance(results.masks, list) + assert isinstance(results.masks[0], tv_tensors.Mask) + + def test_optimize(self, tmpdir, mocker, set_ov_visual_prompting_model) -> None: + """Test optimize.""" + mocker.patch("openvino.Core.read_model") + mocker.patch("openvino.save_model") + mocker.patch("nncf.quantize") + + ov_visual_prompting_model = set_ov_visual_prompting_model() + fake_data_module = Mock() + + results = ov_visual_prompting_model.optimize(tmpdir, fake_data_module) + + assert "image_encoder" in results + assert "decoder" in results + + +class TestOVZeroShotVisualPromptingModel: + @pytest.fixture() + def ov_zero_shot_visual_prompting_model(self, mocker) -> OVZeroShotVisualPromptingModel: + mocker.patch.object( + OVZeroShotVisualPromptingModel, + "_create_model", + return_value={"image_encoder": Mock(), "decoder": Mock()}, + ) + mocker.patch.object(OVZeroShotVisualPromptingModel, "initialize_reference_info") + return OVZeroShotVisualPromptingModel(num_classes=0, model_name="exported_model_decoder.xml") + + 
def test_learn(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_vpm_data_entity) -> None: + """Test learn.""" + ov_zero_shot_visual_prompting_model.reference_feats = np.zeros((0, 1, 256), dtype=np.float32) + ov_zero_shot_visual_prompting_model.used_indices = np.array([], dtype=np.int64) + ov_zero_shot_visual_prompting_model.model["decoder"].mask_threshold = 0.0 + mocker.patch.object( + ov_zero_shot_visual_prompting_model.model["image_encoder"], + "preprocess", + return_value=(np.zeros((1, 3, 1024, 1024)), {"original_shape": (1024, 1024)}), + ) + mocker.patch.object( + ov_zero_shot_visual_prompting_model.model["image_encoder"], + "infer_sync", + return_value={"image_embeddings": np.random.random((1, 256, 64, 64))}, + ) + mocker.patch.object( + ov_zero_shot_visual_prompting_model.model["decoder"], + "preprocess", + return_value=[ + { + "point_coords": np.array([1, 1]).reshape(-1, 1, 2), + "point_labels": np.array([1], dtype=np.float32).reshape(-1, 1), + "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32), + "has_mask_input": np.zeros((1, 1), dtype=np.float32), + "orig_size": np.array([1024, 1024], dtype=np.int64).reshape(-1, 2), + "label": np.array([1], dtype=np.int64), + }, + ], + ) + mocker.patch.object( + ov_zero_shot_visual_prompting_model.model["decoder"], + "infer_sync", + return_value={ + "iou_predictions": np.array([[0.1, 0.3, 0.5, 0.7]]), + "upscaled_masks": np.random.randn(1, 4, 1024, 1024), # noqa: NPY002 + "low_res_masks": np.zeros((1, 4, 64, 64), dtype=np.float32), + }, + ) + mocker.patch.object( + ov_zero_shot_visual_prompting_model.model["decoder"], + "postprocess", + return_value={ + "hard_prediction": np.zeros((1, 1, 1024, 1024), dtype=np.float32), + "soft_prediction": np.zeros((1, 1, 1024, 1024), dtype=np.float32), + "scores": np.zeros((1, 1), dtype=np.float32), + }, + ) + mocker.patch.object( + ov_zero_shot_visual_prompting_model, + "_generate_masked_features", + return_value=np.random.rand(1, 256), # noqa: NPY002 + ) + reference_info, ref_masks = ov_zero_shot_visual_prompting_model.learn( + inputs=fxt_zero_shot_vpm_data_entity[1], + reset_feat=True, + ) + + assert reference_info["reference_feats"].shape == torch.Size((2, 1, 256)) + assert 1 in reference_info["used_indices"] + assert ref_masks[0].shape == torch.Size((2, 1024, 1024)) + + def test_infer(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_vpm_data_entity) -> None: + """Test infer.""" + ov_zero_shot_visual_prompting_model.model["decoder"].mask_threshold = 0.0 + ov_zero_shot_visual_prompting_model.model["decoder"].output_blob_name = "upscaled_masks" + mocker.patch.object( + ov_zero_shot_visual_prompting_model.model["image_encoder"], + "preprocess", + return_value=(np.zeros((1, 3, 1024, 1024)), {"original_shape": (1024, 1024)}), + ) + mocker.patch.object( + ov_zero_shot_visual_prompting_model.model["image_encoder"], + "infer_sync", + return_value={"image_embeddings": np.random.random((1, 256, 64, 64))}, + ) + mocker.patch.object( + ov_zero_shot_visual_prompting_model.model["decoder"], + "preprocess", + return_value=[ + { + "point_coords": np.array([1, 1]).reshape(-1, 1, 2), + "point_labels": np.array([1], dtype=np.float32).reshape(-1, 1), + "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32), + "has_mask_input": np.zeros((1, 1), dtype=np.float32), + "orig_size": np.array([1024, 1024], dtype=np.int64).reshape(-1, 2), + "label": np.array([1], dtype=np.int64), + }, + ], + ) + mocker.patch.object( + ov_zero_shot_visual_prompting_model.model["decoder"], + "infer_sync", + 
return_value={ + "iou_predictions": np.array([[0.1, 0.3, 0.5, 0.7]]), + "upscaled_masks": np.random.randn(1, 4, 1024, 1024), # noqa: NPY002 + "low_res_masks": np.zeros((1, 4, 64, 64), dtype=np.float32), + }, + ) + mocker.patch.object( + ov_zero_shot_visual_prompting_model.model["decoder"], + "postprocess", + return_value={ + "hard_prediction": np.zeros((1, 1, 1024, 1024), dtype=np.float32), + "soft_prediction": np.zeros((1, 1, 1024, 1024), dtype=np.float32), + "scores": np.zeros((1, 1), dtype=np.float32), + }, + ) + mocker.patch.object( + ov_zero_shot_visual_prompting_model.model["decoder"], + "apply_coords", + return_value=np.array([[1, 1], [2, 2]]), + ) + mocker.patch.object( + ov_zero_shot_visual_prompting_model, + "_get_prompt_candidates", + return_value=({1: np.array([[1, 1, 0.5]])}, {1: np.array([[2, 2]])}), + ) + + reference_feats = torch.rand(2, 1, 256) + used_indices = np.array([1]) + + results = ov_zero_shot_visual_prompting_model.infer( + inputs=fxt_zero_shot_vpm_data_entity[1], + reference_feats=reference_feats, + used_indices=used_indices, + ) + + for predicted_masks, used_points in results: + for label, predicted_mask in predicted_masks.items(): + for pm, _ in zip(predicted_mask, used_points[label]): + assert pm.shape == (1024, 1024) + + def test_gather_prompts_with_labels(self, ov_zero_shot_visual_prompting_model) -> None: + """Test _gather_prompts_with_labels.""" + batch_prompts = [ + [ + {"bboxes": "bboxes", "label": 1}, + {"points": "points", "label": 2}, + ], + ] + + processed_prompts = ov_zero_shot_visual_prompting_model._gather_prompts_with_labels(batch_prompts) + + for prompts in processed_prompts: + for label, prompt in prompts.items(): + if label == 1: + assert "bboxes" in prompt[0] + else: + assert "points" in prompt[0] + assert prompt[0]["label"] == label + + def test_initialize_reference_info(self, ov_zero_shot_visual_prompting_model) -> None: + """Test initialize_reference_info.""" + ov_zero_shot_visual_prompting_model.reference_feats = np.zeros((0, 1, 256), dtype=np.float32) + ov_zero_shot_visual_prompting_model.used_indices = np.array([], dtype=np.int64) + + assert ov_zero_shot_visual_prompting_model.reference_feats.shape == (0, 1, 256) + assert ov_zero_shot_visual_prompting_model.used_indices.shape == (0,) + + @pytest.mark.parametrize("new_largest_label", [0, 3]) + def test_expand_reference_info(self, ov_zero_shot_visual_prompting_model, new_largest_label: int) -> None: + """Test expand_reference_info.""" + ov_zero_shot_visual_prompting_model.reference_feats = np.zeros((0, 1, 256)) + + ov_zero_shot_visual_prompting_model.expand_reference_info( + new_largest_label=new_largest_label, + ) + + assert len(ov_zero_shot_visual_prompting_model.reference_feats) == new_largest_label + 1 + + def test_generate_masked_features(self, ov_zero_shot_visual_prompting_model) -> None: + """Test _generate_masked_features.""" + feats = np.random.random((8, 8, 1)) + masks = np.zeros((16, 16), dtype=np.float32) + masks[4:12, 4:12] = 1.0 + + masked_feat = ov_zero_shot_visual_prompting_model._generate_masked_features( + feats=feats, + masks=masks, + threshold_mask=0.3, + image_size=16, + ) + + assert masked_feat.shape == (1, 1) + + def test_pad_to_square(self, ov_zero_shot_visual_prompting_model) -> None: + """Test _pad_to_square.""" + result = ov_zero_shot_visual_prompting_model._pad_to_square(x=np.ones((8, 8)), image_size=16) + + assert result[:8, :8].sum() == 8**2 + assert result[:8, 8:].sum() == 0 + assert result[8:, :8].sum() == 0 + assert result[8:, 8:].sum() == 0 + + def 
test_find_latest_reference_info(self, mocker, ov_zero_shot_visual_prompting_model) -> None: + """Test _find_latest_reference_info.""" + mocker.patch( + "otx.core.model.entity.visual_prompting.os.path.isdir", + return_value=True, + ) + + # there are some saved reference info + mocker.patch( + "otx.core.model.entity.visual_prompting.os.listdir", + return_value=["1", "2"], + ) + results = ov_zero_shot_visual_prompting_model._find_latest_reference_info(Path()) + assert results == "2" + + # there are no saved reference info + mocker.patch( + "otx.core.model.entity.visual_prompting.os.listdir", + return_value=[], + ) + results = ov_zero_shot_visual_prompting_model._find_latest_reference_info(Path()) + assert results is None + + def test_load_latest_reference_info(self, mocker, ov_zero_shot_visual_prompting_model) -> None: + """Test load_latest_reference_info.""" + ov_zero_shot_visual_prompting_model.model["decoder"].embed_dim = 256 + + # get previously saved reference info + mocker.patch.object(ov_zero_shot_visual_prompting_model, "_find_latest_reference_info", return_value="1") + mocker.patch( + "otx.core.model.entity.visual_prompting.pickle.load", + return_value={"reference_feats": np.zeros((1, 1, 256)), "used_indices": np.array([0])}, + ) + mocker.patch("otx.core.model.entity.visual_prompting.Path.open", return_value="Mocked data") + + ov_zero_shot_visual_prompting_model.load_latest_reference_info() + assert ov_zero_shot_visual_prompting_model.reference_feats.shape == (1, 1, 256) + assert ov_zero_shot_visual_prompting_model.used_indices.shape == (1,) + + # no saved reference info + mocker.patch.object(ov_zero_shot_visual_prompting_model, "_find_latest_reference_info", return_value=None) + + ov_zero_shot_visual_prompting_model.reference_feats = np.zeros((0, 1, 256), dtype=np.float32) + ov_zero_shot_visual_prompting_model.used_indices = np.array([], dtype=np.int64) + ov_zero_shot_visual_prompting_model.load_latest_reference_info() + + assert ov_zero_shot_visual_prompting_model.reference_feats.shape == (0, 1, 256) + assert ov_zero_shot_visual_prompting_model.used_indices.shape == (0,) + + @pytest.mark.parametrize( + "result_point_selection", + [np.array([[2, 2, 0.9], [1, 2, 0.8], [0, 2, 0.7], [2, 1, 0.6]]), np.array([[-1, -1, -1]])], + ) + def test_get_prompt_candidates( + self, + mocker, + ov_zero_shot_visual_prompting_model, + result_point_selection: np.ndarray, + ) -> None: + """Test get_prompt_candidates.""" + mocker.patch.object( + ov_zero_shot_visual_prompting_model, + "_point_selection", + return_value=(result_point_selection, torch.zeros(1, 2)), + ) + image_embeddings = np.ones((1, 4, 4, 4)) + reference_feats = np.random.random((1, 1, 4)) + used_indices = np.array([0]) + original_shape = np.array([3, 3], dtype=np.int64) + + total_points_scores, total_bg_coords = ov_zero_shot_visual_prompting_model._get_prompt_candidates( + image_embeddings=image_embeddings, + reference_feats=reference_feats, + used_indices=used_indices, + original_shape=original_shape, + ) + + assert total_points_scores[0].shape[0] == len(result_point_selection) + assert total_bg_coords[0].shape[0] == 1 + + @pytest.mark.parametrize( + ("mask_sim", "expected"), + [ + ( + np.arange(0.1, 1.0, 0.1).reshape(3, 3), + np.array([[2, 2, 0.9], [1, 2, 0.8], [0, 2, 0.7], [2, 1, 0.6]]), + ), + (np.zeros((3, 3)), None), + ], + ) + def test_point_selection( + self, + ov_zero_shot_visual_prompting_model, + mask_sim: np.ndarray, + expected: np.ndarray, + ) -> None: + """Test _point_selection.""" + points_scores, bg_coords = 
ov_zero_shot_visual_prompting_model._point_selection( + mask_sim=mask_sim, + original_shape=np.array([3, 3]), + threshold=np.array([[0.5]]), + num_bg_points=1, + ) + + if points_scores is not None: + assert np.allclose(points_scores, expected) + + def test_resize_to_original_shape(self, ov_zero_shot_visual_prompting_model) -> None: + """Test _resize_to_original_shape.""" + masks = np.random.random((8, 8)) + image_size = 6 + original_shape = np.array([8, 10], dtype=np.int64) + + resized_masks = ov_zero_shot_visual_prompting_model._resize_to_original_shape(masks, image_size, original_shape) + + assert isinstance(resized_masks, np.ndarray) + assert resized_masks.shape == (8, 10) + + def test_get_prepadded_size(self, ov_zero_shot_visual_prompting_model) -> None: + """Test _get_prepadded_size.""" + original_shape = np.array([8, 10], dtype=np.int64) + image_size = 6 + + prepadded_size = ov_zero_shot_visual_prompting_model._get_prepadded_size(original_shape, image_size) + + assert isinstance(prepadded_size, np.ndarray) + assert prepadded_size.dtype == np.int64 + assert prepadded_size.shape == (2,) + assert np.all(prepadded_size == np.array([5, 6], dtype=np.int64)) + + def test_inspect_overlapping_areas(self, mocker, ov_zero_shot_visual_prompting_model) -> None: + """Test _inspect_overlapping_areas.""" + predicted_masks = { + 0: [ + np.array( + [ + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + ), + np.array( + [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + ), + np.array( + [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0], + [0, 1, 1, 1, 0, 0], + ], + ), + ], + 1: [ + np.array( + [ + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + ), + np.array( + [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1], + [0, 0, 0, 0, 1, 1], + ], + ), + np.array( + [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 1, 1, 0, 0, 0], + [0, 1, 1, 0, 0, 0], + ], + ), + np.array( + [ + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + ), + ], + } + used_points = { + 0: [ + np.array([0, 0, 0.5]), # to be removed + np.array([2, 2, 0.5]), + np.array([1, 4, 0.5]), + ], + 1: [ + np.array([3, 0, 0.5]), + np.array([4, 4, 0.5]), + np.array([1, 4, 0.3]), # to be removed + np.array([0, 0, 0.7]), + ], + } + + ov_zero_shot_visual_prompting_model._inspect_overlapping_areas(predicted_masks, used_points, threshold_iou=0.5) + + assert len(predicted_masks[0]) == 2 + assert len(predicted_masks[1]) == 3 + assert all(np.array([2, 2, 0.5]) == used_points[0][0]) + assert all(np.array([0, 0, 0.7]) == used_points[1][2]) + + @pytest.mark.parametrize( + ("largest", "expected_scores", "expected_ind"), + [ + (True, np.array([[3, 2], [6, 5], [9, 8]]), np.array([[2, 1], [2, 1], [2, 1]])), + (False, np.array([[1, 2], [4, 5], [7, 8]]), np.array([[0, 1], [0, 1], [0, 1]])), + ], + ) + def test_topk_numpy( + self, + ov_zero_shot_visual_prompting_model, + largest: bool, + expected_scores: np.ndarray, + expected_ind: np.ndarray, + ) -> None: + """Test _topk_numpy.""" + x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + k = 2 + axis = -1 + + 
scores, ind = ov_zero_shot_visual_prompting_model._topk_numpy(x, k, axis, largest) + + np.testing.assert_array_equal(scores, expected_scores) + np.testing.assert_array_equal(ind, expected_ind)
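Reviewer note: the core behavioural change in `_inspect_overlapping_areas` (both the torch and the NumPy variants) is that `_calculate_mask_iou` now also returns the intersection mask, so pairs of masks from different labels that overlap heavily (IoU above `threshold_iou`) still drop the lower-scoring prediction, while pairs that overlap only slightly keep both masks and merely zero out the shared pixels on the lower-scoring one. A minimal NumPy sketch of that policy, with illustrative names (`mask_iou`, `resolve_overlap`) that are not taken from the patch:

```python
import numpy as np


def mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> tuple[float, np.ndarray | None]:
    """Return (IoU, intersection mask); the intersection is None when the union is empty."""
    union = np.logical_or(mask1, mask2).sum()
    if union == 0:
        return 0.0, None  # avoid division by zero
    intersection = np.logical_and(mask1, mask2)
    return float(intersection.sum()) / float(union), intersection


def resolve_overlap(
    mask: np.ndarray,
    score: float,
    other_mask: np.ndarray,
    other_score: float,
    threshold_iou: float = 0.8,
) -> tuple[bool, bool]:
    """Apply the overlap policy to one pair of masks predicted for different labels.

    Returns (drop_mask, drop_other_mask). Slightly overlapping pairs are refined
    in place instead of being dropped.
    """
    iou, intersection = mask_iou(mask, other_mask)
    if iou > threshold_iou:
        # heavy overlap: keep only the higher-scoring prediction
        return (score <= other_score, score > other_score)
    if iou > 0:
        # slight overlap: zero out the shared pixels on the lower-scoring mask
        coords = np.where(intersection)
        if score > other_score:
            other_mask[coords] = 0
        else:
            mask[coords] = 0
    return (False, False)


# Example: two 4x4 masks sharing a single pixel; the lower-scoring mask loses it.
a = np.zeros((4, 4), dtype=np.uint8); a[:2, :2] = 1
b = np.zeros((4, 4), dtype=np.uint8); b[1:3, 1:3] = 1
drop_a, drop_b = resolve_overlap(a, score=0.9, other_mask=b, other_score=0.4)
assert (drop_a, drop_b) == (False, False) and b[1, 1] == 0
```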
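The new `OVVisualPromptingModel.optimize` runs post-training quantization twice, once per IR (image encoder and decoder), building an `nncf.Dataset` whose transform function produces module-specific inputs: preprocessed images for the encoder, and the first prompt merged with the encoder's image embeddings for the decoder. A condensed sketch of that control flow under stated assumptions: `calibration_batches` is any iterable of batches and `prepare_inputs(batch, module)` is a hypothetical stand-in for the `_customize_inputs`-based preprocessing in the patch.

```python
from functools import partial
from pathlib import Path
from typing import Any, Callable, Iterable

import nncf
import openvino


def quantize_sam_modules(
    model_paths: dict[str, str],                # {"image_encoder": "...xml", "decoder": "...xml"}
    calibration_batches: Iterable[Any],
    prepare_inputs: Callable[..., Any],         # hypothetical (batch, module) -> model inputs
    output_dir: Path,
    ptq_config: dict[str, Any] | None = None,
) -> dict[str, Path]:
    """Quantize each visual prompting submodule into its own optimized IR."""
    outputs: dict[str, Path] = {}
    for module in ("image_encoder", "decoder"):
        ov_model = openvino.Core().read_model(model_paths[module])
        if any(op.get_type_name() == "FakeQuantize" for op in ov_model.get_ops()):
            raise RuntimeError("Model is already optimized by PTQ")

        # one calibration dataset per module, each with its own input transform
        dataset = nncf.Dataset(calibration_batches, partial(prepare_inputs, module=module))
        compressed = nncf.quantize(ov_model, dataset, **(ptq_config or {}))

        out_path = output_dir / f"optimized_model_{module}.xml"
        openvino.save_model(compressed, str(out_path))
        outputs[module] = out_path
    return outputs
```

This per-module output is also why the integration test above now looks for `optimized_model_decoder.xml` instead of a single `optimized_model.xml`.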
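Finally, the `_topk_numpy` fix makes the flip of the sorted indices conditional on `largest`; before the patch the flip happened unconditionally. A simpler argsort-based equivalent (not the patch's argpartition implementation) that reproduces the expectations in `test_topk_numpy`:

```python
import numpy as np


def topk_numpy(x: np.ndarray, k: int, axis: int = -1, largest: bool = True):
    """Return the top-k (scores, indices) along `axis`, sorted like torch.topk."""
    order = np.argsort(x, axis=axis)            # ascending indices
    if largest:
        order = np.flip(order, axis=axis)       # descend only when the largest values are wanted
    ind = np.take(order, range(k), axis=axis)   # keep the first k along `axis`
    scores = np.take_along_axis(x, ind, axis=axis)
    return scores, ind


x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
assert np.array_equal(topk_numpy(x, 2)[0], [[3, 2], [6, 5], [9, 8]])
assert np.array_equal(topk_numpy(x, 2, largest=False)[0], [[1, 2], [4, 5], [7, 8]])
```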