From fb3c997b332b4b27f425e26bdae5472621a1a2f6 Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Tue, 11 Jun 2024 16:26:41 +0900 Subject: [PATCH 1/6] Fix `learn` for ov model --- src/otx/core/model/visual_prompting.py | 9 +- .../visual_prompting/test_openvino_models.py | 63 ++++++----- tests/unit/core/conftest.py | 2 +- .../unit/core/model/test_visual_prompting.py | 102 ++++++------------ 4 files changed, 75 insertions(+), 101 deletions(-) diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py index 48c855c5f3c..073f19c7314 100644 --- a/src/otx/core/model/visual_prompting.py +++ b/src/otx/core/model/visual_prompting.py @@ -132,10 +132,7 @@ def _inference_step_for_zero_shot( if _name == "mAP": # MeanAveragePrecision _preds = [ - { - k: v > 0.5 if k == "masks" else v.squeeze(1).to(model.device) if k == "labels" else v - for k, v in ett.items() - } + {k: v > 0.5 if k == "masks" else v.to(model.device) if k == "labels" else v for k, v in ett.items()} for ett in converted_entities["preds"] ] _target = converted_entities["target"] @@ -828,7 +825,7 @@ def learn( for label, input_prompts in prompts.items(): ref_mask: np.ndarray = np.zeros(original_shape, dtype=np.uint8) for inputs_decoder in input_prompts: - label = inputs_decoder.pop("label") # noqa: PLW2901 + inputs_decoder.pop("label") if "point_coords" in inputs_decoder: # bboxes and points inputs_decoder.update(image_embeddings) @@ -853,7 +850,7 @@ def learn( cur_default_threshold_reference -= 0.05 self.reference_feats[label] = ref_feat - self.used_indices: np.ndarray = np.concatenate((self.used_indices, label)) + self.used_indices: np.ndarray = np.concatenate((self.used_indices, [label])) ref_masks[label] = ref_mask reference_masks.append(ref_masks) self.used_indices = np.unique(self.used_indices) diff --git a/tests/unit/algo/visual_prompting/test_openvino_models.py b/tests/unit/algo/visual_prompting/test_openvino_models.py index 634d198670e..71eb29e90bf 100644 --- a/tests/unit/algo/visual_prompting/test_openvino_models.py +++ b/tests/unit/algo/visual_prompting/test_openvino_models.py @@ -14,14 +14,14 @@ class TestVisualPromptingImageEncoder: - def test_parameters(self): + def test_parameters(self) -> None: """Test parameters.""" params = VisualPromptingImageEncoder.parameters() assert params.get("resize_type").default_value == "fit_to_window" assert params.get("image_size").default_value == 1024 - def test_preproces(self, mocker): + def test_preproces(self, mocker) -> None: """Test preprocess.""" mocker.patch.object(ImageModel, "__init__") image_encoder = VisualPromptingImageEncoder("adapter") @@ -41,63 +41,76 @@ def test_preproces(self, mocker): class TestVisualPromptingDecoder: @pytest.fixture(autouse=True) - def setup(self, mocker): + def setup(self, mocker) -> None: mocker.patch.object(SegmentationModel, "__init__") mocker_model_adapter = mocker.Mock(spec=OpenvinoAdapter) self.decoder = VisualPromptingDecoder(mocker_model_adapter) self.decoder.image_size = 6 - def test_parameters(self): + def test_parameters(self) -> None: """Test parameters.""" params = VisualPromptingDecoder.parameters() assert isinstance(params.get("image_size"), NumericalValue) assert params.get("image_size").default_value == 1024 - def test_get_outputs(self): + def test_get_outputs(self) -> None: """Test _get_outputs.""" results = self.decoder._get_outputs() assert results == "upscaled_masks" @pytest.mark.parametrize( - ("prompts", "expected"), + ("prompts", "prompt_type", "expected"), [ ( { "bboxes": [np.array([[1, 1], [2, 
2]])], "points": [], - "labels": {"bboxes": [1]}, + "labels": {"bboxes": [np.array(1)]}, "orig_size": (4, 4), }, + "bboxes", { - "point_coords": (1, 2, 2), - "point_labels": (1, 2), + "point_coords": np.array([[[1.5, 1.5], [3.0, 3.0]]]), + "point_labels": np.array([[2.0, 3.0]]), }, ), ( - {"bboxes": [], "points": [np.array([[1, 1]])], "labels": {"points": [1]}, "orig_size": (4, 4)}, { - "point_coords": (1, 1, 2), - "point_labels": (1, 1), + "bboxes": [], + "points": [np.array([[1, 1]])], + "labels": {"points": [np.array(1)]}, + "orig_size": (4, 4), + }, + "points", + { + "point_coords": np.array([[[1.5, 1.5]]]), + "point_labels": np.array([[1.0]]), }, ), ], ) - def test_preprocess(self, prompts: dict[str, Any], expected: dict[str, Any]): + def test_preprocess(self, prompts: dict[str, Any], prompt_type: str, expected: dict[str, Any]) -> None: """Test preprocess""" results = self.decoder.preprocess(prompts) assert isinstance(results, list) - assert "point_coords" in results[0] - assert results[0]["point_coords"].shape == expected["point_coords"] - assert "point_labels" in results[0] - assert results[0]["point_labels"].shape == expected["point_labels"] - assert "mask_input" in results[0] - assert "has_mask_input" in results[0] - assert "orig_size" in results[0] - - def test_apply_coords(self): + for i in range(len(results)): + assert "point_coords" in results[i] + assert np.all(results[i]["point_coords"] == expected["point_coords"]) + assert "point_labels" in results[i] + assert np.all(results[i]["point_labels"] == expected["point_labels"]) + assert "mask_input" in results[i] + assert np.all(results[i]["mask_input"] == self.decoder.mask_input) + assert "has_mask_input" in results[i] + assert np.all(results[i]["has_mask_input"] == self.decoder.has_mask_input) + assert "orig_size" in results[i] + assert np.all(results[i]["orig_size"] == prompts["orig_size"]) + assert "label" in results[i] + assert np.all(results[i]["label"] == prompts["labels"][prompt_type][i]) + + def test_apply_coords(self) -> None: """Test apply_coords.""" coords = np.array([[[1, 1], [2, 2]]]) original_size = (12, 12) @@ -114,13 +127,13 @@ def test_apply_coords(self): (3, 4, 6, (5, 6)), ], ) - def test_get_preprocess_shape(self, old_h: int, old_w: int, image_size: int, expected: tuple[int]): + def test_get_preprocess_shape(self, old_h: int, old_w: int, image_size: int, expected: tuple[int]) -> None: """Test _get_preprocess_shape.""" result = self.decoder._get_preprocess_shape(old_h, old_w, image_size) assert result == expected - def test_get_inputs(self): + def test_get_inputs(self) -> None: """Test _get_inputs.""" self.decoder.inputs = {"images": np.ones((1, 4, 4, 3))} @@ -128,7 +141,7 @@ def test_get_inputs(self): assert returned_value[0] == ["images"] - def test_postprocess(self, mocker): + def test_postprocess(self, mocker) -> None: """Test postprocess.""" self.decoder.output_blob_name = "upscaled_masks" self.decoder.mask_threshold = 0.0 diff --git a/tests/unit/core/conftest.py b/tests/unit/core/conftest.py index bfdf0ea3a8b..5cf87a37a9b 100644 --- a/tests/unit/core/conftest.py +++ b/tests/unit/core/conftest.py @@ -151,7 +151,7 @@ def fxt_zero_shot_vpm_data_entity() -> ( ) fake_points = Points([[2, 2]], canvas_size=img_size, dtype=torch.float32) fake_masks = tv_tensors.Mask(torch.ones(1, *img_size)) - fake_labels = torch.as_tensor([[1]], dtype=torch.int64) + fake_labels = torch.as_tensor([1, 2], dtype=torch.int64) fake_polygons = [None] fake_scores = torch.tensor([[1.0]]) # define data entity diff --git 
a/tests/unit/core/model/test_visual_prompting.py b/tests/unit/core/model/test_visual_prompting.py index 1e13edea7b7..e879a5c9ae2 100644 --- a/tests/unit/core/model/test_visual_prompting.py +++ b/tests/unit/core/model/test_visual_prompting.py @@ -377,33 +377,9 @@ def ov_zero_shot_visual_prompting_model(self, mocker, tmpdir) -> OVZeroShotVisua (dirpath / "exported_model_decoder.xml").touch() model_name = str(dirpath / "exported_model_decoder.xml") - return OVZeroShotVisualPromptingModel(num_classes=0, model_name=model_name) + ov_zero_shot_visual_prompting_model = OVZeroShotVisualPromptingModel(num_classes=0, model_name=model_name) - @pytest.mark.parametrize("training", [True, False]) - def test_forward( - self, - mocker, - ov_zero_shot_visual_prompting_model, - fxt_zero_shot_vpm_data_entity, - training: bool, - ) -> None: - """Test forward.""" - ov_zero_shot_visual_prompting_model.training = training - ov_zero_shot_visual_prompting_model.reference_feats = "reference_feats" - ov_zero_shot_visual_prompting_model.used_indices = "used_indices" - mocker_fn = mocker.patch.object(ov_zero_shot_visual_prompting_model, "learn" if training else "infer") - mocker_customize_outputs = mocker.patch.object(ov_zero_shot_visual_prompting_model, "_customize_outputs") - - ov_zero_shot_visual_prompting_model.forward(fxt_zero_shot_vpm_data_entity[1]) - - mocker_fn.assert_called_once() - mocker_customize_outputs.assert_called_once() - - def test_learn(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_vpm_data_entity) -> None: - """Test learn.""" - ov_zero_shot_visual_prompting_model.reference_feats = np.zeros((0, 1, 256), dtype=np.float32) - ov_zero_shot_visual_prompting_model.used_indices = np.array([], dtype=np.int64) - ov_zero_shot_visual_prompting_model.model["decoder"].mask_threshold = 0.0 + # mocking mocker.patch.object( ov_zero_shot_visual_prompting_model.model["image_encoder"], "preprocess", @@ -424,7 +400,7 @@ def test_learn(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_ "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32), "has_mask_input": np.zeros((1, 1), dtype=np.float32), "orig_size": np.array([1024, 1024], dtype=np.int64).reshape(-1, 2), - "label": np.array([1], dtype=np.int64), + "label": np.array(1, dtype=np.int64), }, ], ) @@ -446,6 +422,35 @@ def test_learn(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_ "scores": np.zeros((1, 1), dtype=np.float32), }, ) + + return ov_zero_shot_visual_prompting_model + + @pytest.mark.parametrize("training", [True, False]) + def test_forward( + self, + mocker, + ov_zero_shot_visual_prompting_model, + fxt_zero_shot_vpm_data_entity, + training: bool, + ) -> None: + """Test forward.""" + ov_zero_shot_visual_prompting_model.training = training + ov_zero_shot_visual_prompting_model.reference_feats = "reference_feats" + ov_zero_shot_visual_prompting_model.used_indices = "used_indices" + mocker_fn = mocker.patch.object(ov_zero_shot_visual_prompting_model, "learn" if training else "infer") + mocker_customize_outputs = mocker.patch.object(ov_zero_shot_visual_prompting_model, "_customize_outputs") + + ov_zero_shot_visual_prompting_model.forward(fxt_zero_shot_vpm_data_entity[1]) + + mocker_fn.assert_called_once() + mocker_customize_outputs.assert_called_once() + + def test_learn(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_vpm_data_entity) -> None: + """Test learn.""" + ov_zero_shot_visual_prompting_model.reference_feats = np.zeros((0, 1, 256), dtype=np.float32) + 
ov_zero_shot_visual_prompting_model.used_indices = np.array([], dtype=np.int64) + ov_zero_shot_visual_prompting_model.model["decoder"].mask_threshold = 0.0 + mocker.patch.object( ov_zero_shot_visual_prompting_model, "_generate_masked_features", @@ -464,48 +469,7 @@ def test_infer(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_ """Test infer.""" ov_zero_shot_visual_prompting_model.model["decoder"].mask_threshold = 0.0 ov_zero_shot_visual_prompting_model.model["decoder"].output_blob_name = "upscaled_masks" - mocker.patch.object( - ov_zero_shot_visual_prompting_model.model["image_encoder"], - "preprocess", - return_value=(np.zeros((1, 3, 1024, 1024)), {"original_shape": (1024, 1024)}), - ) - mocker.patch.object( - ov_zero_shot_visual_prompting_model.model["image_encoder"], - "infer_sync", - return_value={"image_embeddings": np.random.random((1, 256, 64, 64))}, - ) - mocker.patch.object( - ov_zero_shot_visual_prompting_model.model["decoder"], - "preprocess", - return_value=[ - { - "point_coords": np.array([1, 1]).reshape(-1, 1, 2), - "point_labels": np.array([1], dtype=np.float32).reshape(-1, 1), - "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32), - "has_mask_input": np.zeros((1, 1), dtype=np.float32), - "orig_size": np.array([1024, 1024], dtype=np.int64).reshape(-1, 2), - "label": np.array([1], dtype=np.int64), - }, - ], - ) - mocker.patch.object( - ov_zero_shot_visual_prompting_model.model["decoder"], - "infer_sync", - return_value={ - "iou_predictions": np.array([[0.1, 0.3, 0.5, 0.7]]), - "upscaled_masks": np.random.randn(1, 4, 1024, 1024), - "low_res_masks": np.zeros((1, 4, 64, 64), dtype=np.float32), - }, - ) - mocker.patch.object( - ov_zero_shot_visual_prompting_model.model["decoder"], - "postprocess", - return_value={ - "hard_prediction": np.zeros((1, 1, 1024, 1024), dtype=np.float32), - "soft_prediction": np.zeros((1, 1, 1024, 1024), dtype=np.float32), - "scores": np.zeros((1, 1), dtype=np.float32), - }, - ) + mocker.patch.object( ov_zero_shot_visual_prompting_model.model["decoder"], "apply_coords", From 696b9221cda771074a914b0198abe1115d0c39fe Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Tue, 11 Jun 2024 22:27:26 +0900 Subject: [PATCH 2/6] Fix integration test --- src/otx/algo/visual_prompting/zero_shot_segment_anything.py | 2 +- src/otx/core/model/visual_prompting.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py index 4a18464f1eb..55206fc0969 100644 --- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py +++ b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py @@ -774,7 +774,7 @@ def _customize_outputs( # type: ignore[override] ), ) scores.append(torch.stack([p[2] for p in used_points[label]], dim=0)) - labels.append(torch.stack([LongTensor([label]) for _ in range(scores[-1].shape[0])], dim=0)) + labels.append(torch.cat([LongTensor([label]) for _ in range(scores[-1].shape[0])], dim=0)) return ZeroShotVisualPromptingBatchPredEntity( batch_size=len(outputs), diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py index 073f19c7314..e0a7b94ed7c 100644 --- a/src/otx/core/model/visual_prompting.py +++ b/src/otx/core/model/visual_prompting.py @@ -1035,7 +1035,7 @@ def _customize_outputs( # type: ignore[override] ) scores.append(torch.stack([torch.as_tensor(p[2]) for p in used_points[label]], dim=0).to(self.device)) labels.append( - 
torch.stack([torch.LongTensor([label]) for _ in range(len(scores[-1]))], dim=0).to(self.device), + torch.cat([torch.LongTensor([label]) for _ in range(len(scores[-1]))], dim=0).to(self.device), ) return ZeroShotVisualPromptingBatchPredEntity( From 6c524ff9d5319e4795dc788b828fe359445a3657 Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Wed, 12 Jun 2024 15:09:21 +0900 Subject: [PATCH 3/6] Add integration tests for `infer` without predefined `reference_info` --- .../zero_shot_segment_anything.py | 64 +++++++++----- src/otx/core/model/visual_prompting.py | 87 ++++++++++++++----- .../integration/cli/test_export_inference.py | 80 +++++++++++++---- 3 files changed, 171 insertions(+), 60 deletions(-) diff --git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py index 55206fc0969..8612fd2b3eb 100644 --- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py +++ b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py @@ -886,33 +886,57 @@ def save_reference_info(self, default_root_dir: Path | str) -> None: "used_indices": self.used_indices, } # save reference info - path_reference_info: Path = Path(default_root_dir) / self.reference_info_dir / "reference_info.pt" - path_reference_info.parent.mkdir(parents=True, exist_ok=True) + self.saved_reference_info_path: Path = Path(default_root_dir) / self.reference_info_dir / "reference_info.pt" + self.saved_reference_info_path.parent.mkdir(parents=True, exist_ok=True) # TODO (sungchul): ticket no. 139210 - torch.save(reference_info, path_reference_info) + torch.save(reference_info, self.saved_reference_info_path) pickle.dump( {k: v.numpy() for k, v in reference_info.items()}, - path_reference_info.with_suffix(".pickle").open("wb"), - ) - log.info(f"Saved reference info at {path_reference_info}.") - - def load_reference_info(self, default_root_dir: Path | str, device: str | torch.device = "cpu") -> bool: - """Load latest reference info to be used.""" - _infer_reference_info_root: Path = ( - self.infer_reference_info_root - if self.infer_reference_info_root == self.infer_reference_info_root.absolute() - else Path(default_root_dir) / self.infer_reference_info_root + self.saved_reference_info_path.with_suffix(".pickle").open("wb"), ) + log.info(f"Saved reference info at {self.saved_reference_info_path}.") + + def load_reference_info( + self, + default_root_dir: Path | str, + device: str | torch.device = "cpu", + path_to_directly_load: Path | None = None, + ) -> bool: + """Load latest reference info to be used. - if ( - path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pt" - ).is_file(): - reference_info = torch.load(path_reference_info) + Args: + default_root_dir (Path | str): Default root directory to be used + when inappropriate infer_reference_info_root is given. + device (str | torch.device): Device that reference infos will be attached. + path_to_directly_load (Path | None): Reference info path to directly be loaded. + Normally, it is obtained after `learn` which is executed when trying to do `infer` + without reference features in `on_test_start` or `on_predict_start`. + + Returns: + (bool): Whether normally loading checkpoint or not. 
+ """ + if path_to_directly_load is not None: + # if `path_to_directly_load` is given, forcely load + reference_info = torch.load(path_to_directly_load) retval = True - log.info(f"reference info saved at {path_reference_info} was successfully loaded.") + log.info(f"reference info saved at {path_to_directly_load} was successfully loaded.") + else: - reference_info = {} - retval = False + _infer_reference_info_root: Path = ( + self.infer_reference_info_root + if self.infer_reference_info_root == self.infer_reference_info_root.absolute() + else Path(default_root_dir) / self.infer_reference_info_root + ) + + if ( + path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pt" + ).is_file(): + reference_info = torch.load(path_reference_info) + retval = True + log.info(f"reference info saved at {path_reference_info} was successfully loaded.") + else: + reference_info = {} + retval = False self.register_buffer( "reference_feats", diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py index e0a7b94ed7c..0d55070dd25 100644 --- a/src/otx/core/model/visual_prompting.py +++ b/src/otx/core/model/visual_prompting.py @@ -348,7 +348,11 @@ def on_test_start(self) -> None: # to set _combined_loader self.trainer._evaluation_loop.setup_data() # noqa: SLF001 self.trainer._evaluation_loop.reset() # noqa: SLF001 - self.load_reference_info(self.trainer.default_root_dir, self.device) + self.load_reference_info( + self.trainer.default_root_dir, + self.device, + path_to_directly_load=self.saved_reference_info_path, + ) def on_predict_start(self) -> None: """Load previously saved reference info.""" @@ -361,7 +365,11 @@ def on_predict_start(self) -> None: # to set _combined_loader self.trainer._evaluation_loop.setup_data() # noqa: SLF001 self.trainer._evaluation_loop.reset() # noqa: SLF001 - self.load_reference_info(self.trainer.default_root_dir, self.device) + self.load_reference_info( + self.trainer.default_root_dir, + self.device, + path_to_directly_load=self.saved_reference_info_path, + ) def on_train_epoch_start(self) -> None: """Skip on_train_epoch_start unused in zero-shot visual prompting.""" @@ -1260,12 +1268,17 @@ def save_reference_info(self, default_root_dir: Path | str) -> None: "used_indices": self.used_indices, } # save reference info - path_reference_info: Path = Path(default_root_dir) / self.reference_info_dir / "reference_info.pt" - path_reference_info.parent.mkdir(parents=True, exist_ok=True) + self.saved_reference_info_path: Path = ( + Path(default_root_dir) / self.reference_info_dir / "reference_info.pickle" + ) + self.saved_reference_info_path.parent.mkdir(parents=True, exist_ok=True) # TODO (sungchul): ticket no. 
139210 - torch.save({k: torch.as_tensor(v) for k, v in reference_info.items()}, path_reference_info) - pickle.dump(reference_info, path_reference_info.with_suffix(".pickle").open("wb")) - log.info(f"Saved reference info at {path_reference_info}.") + torch.save( + {k: torch.as_tensor(v) for k, v in reference_info.items()}, + self.saved_reference_info_path.with_suffix(".pt"), + ) + pickle.dump(reference_info, self.saved_reference_info_path.open("wb")) + log.info(f"Saved reference info at {self.saved_reference_info_path}.") def _generate_masked_features( self, @@ -1319,8 +1332,40 @@ def _pad_to_square(self, x: np.ndarray, image_size: int = 1024) -> np.ndarray: ###################################### # Infer # ###################################### - def load_reference_info(self, default_root_dir: Path | str, *args, **kwargs) -> bool: - """Load latest reference info to be used.""" + def load_reference_info( + self, + default_root_dir: Path | str, + *args, + path_to_directly_load: Path | None = None, + **kwargs, + ) -> bool: + """Load latest reference info to be used. + + Args: + default_root_dir (Path | str): Default root directory to be used + when inappropriate infer_reference_info_root is given. + path_to_directly_load (Path | None): Reference info path to directly be loaded. + Normally, it is obtained after `learn` which is executed when trying to do `infer` + without reference features in `on_test_start` or `on_predict_start`. + + Returns: + (bool): Whether normally loading checkpoint or not. + """ + + def _load_and_assign_reference_info(path: Path) -> bool: + reference_info: dict[str, np.ndarray] = pickle.load(path.open("rb")) # noqa: S301 # nosec: B301 + self.reference_feats = reference_info.get( + "reference_feats", + np.zeros((0, 1, self.model["decoder"].embed_dim), dtype=np.float32), + ) + self.used_indices = reference_info.get("used_indices", np.array([], dtype=np.int64)) + log.info(f"reference info saved at {path} was successfully loaded.") + return True + + if path_to_directly_load is not None: + # if `path_to_directly_load` is given, forcely load + return _load_and_assign_reference_info(path_to_directly_load) + _infer_reference_info_root: Path = ( self.infer_reference_info_root if self.infer_reference_info_root == self.infer_reference_info_root.absolute() @@ -1330,14 +1375,8 @@ def load_reference_info(self, default_root_dir: Path | str, *args, **kwargs) -> if ( path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pickle" ).is_file(): - reference_info: dict[str, np.ndarray] = pickle.load(path_reference_info.open("rb")) # noqa: S301 # nosec: B301 - self.reference_feats = reference_info.get( - "reference_feats", - np.zeros((0, 1, self.model["decoder"].embed_dim), dtype=np.float32), - ) - self.used_indices = reference_info.get("used_indices", np.array([], dtype=np.int64)) - log.info(f"reference info saved at {path_reference_info} was successfully loaded.") - return True + return _load_and_assign_reference_info(path_reference_info) + return False def _get_prompt_candidates( @@ -1524,7 +1563,7 @@ def on_train_start(self) -> None: def on_test_start(self) -> None: """Load previously saved reference info.""" super().on_test_start() - if not self.load_reference_info(self.trainer.default_root_dir, self.device): + if not self.load_reference_info(self.trainer.default_root_dir): log.warning("No reference info found. 
`Learn` will be automatically executed first.") self.trainer.lightning_module.automatic_optimization = False self.trainer.fit_loop.run() @@ -1533,11 +1572,14 @@ def on_test_start(self) -> None: # to set _combined_loader self.trainer._evaluation_loop.setup_data() # noqa: SLF001 self.trainer._evaluation_loop.reset() # noqa: SLF001 - self.load_reference_info(self.trainer.default_root_dir, self.device) + self.load_reference_info( + self.trainer.default_root_dir, + path_to_directly_load=self.saved_reference_info_path, + ) def on_predict_start(self) -> None: """Load previously saved reference info.""" - if not self.load_reference_info(self.trainer.default_root_dir, self.device): + if not self.load_reference_info(self.trainer.default_root_dir): log.warning("No reference info found. `Learn` will be automatically executed first.") self.trainer.lightning_module.automatic_optimization = False self.trainer.fit_loop.run() @@ -1546,7 +1588,10 @@ def on_predict_start(self) -> None: # to set _combined_loader self.trainer._evaluation_loop.setup_data() # noqa: SLF001 self.trainer._evaluation_loop.reset() # noqa: SLF001 - self.load_reference_info(self.trainer.default_root_dir, self.device) + self.load_reference_info( + self.trainer.default_root_dir, + path_to_directly_load=self.saved_reference_info_path, + ) def on_train_epoch_start(self) -> None: """Skip on_train_epoch_start unused in zero-shot visual prompting.""" diff --git a/tests/integration/cli/test_export_inference.py b/tests/integration/cli/test_export_inference.py index d09234297c1..15af715ab61 100644 --- a/tests/integration/cli/test_export_inference.py +++ b/tests/integration/cli/test_export_inference.py @@ -1,6 +1,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - +from __future__ import annotations import logging from pathlib import Path @@ -141,7 +141,13 @@ def test_otx_export_infer( assert len(ckpt_files) > 0 # 2) otx test - def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device: str = fxt_accelerator) -> Path: + def run_cli_test( + test_recipe: str, + checkpoint_path: str, + work_dir: Path, + device: str = fxt_accelerator, + cli_override_command: list[str] | None = None, + ) -> Path: tmp_path_test = tmp_path / f"otx_test_{model_name}" command_cfg = [ "otx", @@ -159,24 +165,30 @@ def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device: checkpoint_path, ] - # Zero-shot visual prompting needs to specify `infer_reference_info_root` - if task in ["zero_shot_visual_prompting"]: - try: - idx_task = checkpoint_path.split("/").index(f"otx_train_{model_name}") - except ValueError: - idx_task = checkpoint_path.split("/").index(f"otx_test_{model_name}") - - command_cfg.extend( - [ - "--model.init_args.infer_reference_info_root", - str(Path(checkpoint_path).parents[-idx_task] / f"otx_train_{model_name}/outputs/.latest/train"), - ], - ) + if cli_override_command is not None: + command_cfg.extend(cli_override_command) + run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) return tmp_path_test - tmp_path_test = run_cli_test(recipe, str(ckpt_files[-1]), Path("outputs") / "torch") + checkpoint_path: str = str(ckpt_files[-1]) + tmp_path_test = run_cli_test(recipe, checkpoint_path, Path("outputs") / "torch") + + if task in ("zero_shot_visual_prompting"): + # Check when using reference infos obtained by otx train + idx_task = checkpoint_path.split("/").index(f"otx_train_{model_name}") + infer_reference_info_root = [ + "--model.init_args.infer_reference_info_root", + 
str(Path(checkpoint_path).parents[-idx_task] / f"otx_train_{model_name}/outputs/.latest/train"), + ] + + tmp_path_test = run_cli_test( + recipe, + checkpoint_path, + Path("outputs") / "torch", + cli_override_command=infer_reference_info_root, + ) assert (tmp_path_test / "outputs").exists() @@ -231,6 +243,21 @@ def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device: tmp_path_test = run_cli_test(export_test_recipe, exported_model_path, Path("outputs") / "openvino", "cpu") assert (tmp_path_test / "outputs").exists() + if task in ("zero_shot_visual_prompting"): + # Check when using reference infos obtained by otx train + idx_task = exported_model_path.split("/").index(f"otx_test_{model_name}") + infer_reference_info_root = [ + "--model.init_args.infer_reference_info_root", + str(Path(exported_model_path).parents[-idx_task] / f"otx_train_{model_name}/outputs/.latest/train"), + ] + tmp_path_test = run_cli_test( + export_test_recipe, + exported_model_path, + Path("outputs") / "openvino", + "cpu", + cli_override_command=infer_reference_info_root, + ) + # 5) test optimize command_cfg = [ "otx", @@ -257,12 +284,27 @@ def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device: ) assert latest_dir.exists() if task in ("visual_prompting", "zero_shot_visual_prompting"): - exported_model_path = str(latest_dir / "optimized_model_decoder.xml") + optimized_model_path = str(latest_dir / "optimized_model_decoder.xml") else: - exported_model_path = str(latest_dir / "optimized_model.xml") + optimized_model_path = str(latest_dir / "optimized_model.xml") # 6) test optimized model - tmp_path_test = run_cli_test(export_test_recipe, exported_model_path, Path("outputs") / "nncf_ptq", "cpu") + tmp_path_test = run_cli_test(export_test_recipe, optimized_model_path, Path("outputs") / "nncf_ptq", "cpu") + if task in ("zero_shot_visual_prompting"): + # Check when using reference infos obtained by otx train + idx_task = optimized_model_path.split("/").index(f"otx_test_{model_name}") + infer_reference_info_root = [ + "--model.init_args.infer_reference_info_root", + str(Path(optimized_model_path).parents[-idx_task] / f"otx_train_{model_name}/outputs/.latest/train"), + ] + tmp_path_test = run_cli_test( + export_test_recipe, + optimized_model_path, + Path("outputs") / "nncf_ptq", + "cpu", + cli_override_command=infer_reference_info_root, + ) + torch_outputs_dir = tmp_path_test / "outputs" / "torch" torch_latest_dir = max( (p for p in torch_outputs_dir.iterdir() if p.is_dir() and p.name != ".latest"), From 32bf091d6026642d548db7505d2ca9281c48e29a Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Wed, 12 Jun 2024 15:12:51 +0900 Subject: [PATCH 4/6] Enable to handle `threshold` and `num_bg_points` through arguments --- src/otx/algo/visual_prompting/zero_shot_segment_anything.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py index 8612fd2b3eb..6f4aa166715 100644 --- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py +++ b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py @@ -714,12 +714,16 @@ def infer( inputs: ZeroShotVisualPromptingBatchDataEntity, reference_feats: Tensor | None = None, used_indices: Tensor | None = None, + threshold: float = 0.0, + num_bg_points: int = 1, is_cascade: bool = True, ) -> ZeroShotVisualPromptingBatchPredEntity | OTXBatchLossEntity: """Infer to directly connect to the model.""" self.training = False 
         outputs = self.model.infer(
             **self._customize_inputs(inputs, reference_feats=reference_feats, used_indices=used_indices),
+            threshold=threshold,
+            num_bg_points=num_bg_points,
             is_cascade=is_cascade,
         )
         return self._customize_outputs(outputs, inputs)

From 835a1b0bf6dd3f3b12e5befe2eed11c9656a45a6 Mon Sep 17 00:00:00 2001
From: "Kim, Sungchul"
Date: Wed, 12 Jun 2024 15:29:43 +0900
Subject: [PATCH 5/6] Fix unit tests

---
 tests/unit/core/model/test_visual_prompting.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/unit/core/model/test_visual_prompting.py b/tests/unit/core/model/test_visual_prompting.py
index e879a5c9ae2..3d13b6df1e2 100644
--- a/tests/unit/core/model/test_visual_prompting.py
+++ b/tests/unit/core/model/test_visual_prompting.py
@@ -229,6 +229,7 @@ def test_on_test_start(self, mocker, otx_zero_shot_visual_prompting_model) -> No
             "setup_data",
         )
         mocker_reset = mocker.patch.object(otx_zero_shot_visual_prompting_model.trainer._evaluation_loop, "reset")
+        otx_zero_shot_visual_prompting_model.saved_reference_info_path = "path"
 
         otx_zero_shot_visual_prompting_model.on_test_start()
 
@@ -246,6 +247,7 @@ def test_on_predict_start(self, mocker, otx_zero_shot_visual_prompting_model) ->
             "setup_data",
         )
         mocker_reset = mocker.patch.object(otx_zero_shot_visual_prompting_model.trainer._evaluation_loop, "reset")
+        otx_zero_shot_visual_prompting_model.saved_reference_info_path = "path"
 
         otx_zero_shot_visual_prompting_model.on_predict_start()

From 9821e8520d32fc8da564754c9b823064f7d6ce65 Mon Sep 17 00:00:00 2001
From: "Kim, Sungchul"
Date: Wed, 12 Jun 2024 17:22:08 +0900
Subject: [PATCH 6/6] Fix integration test

---
 tests/integration/cli/test_export_inference.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/integration/cli/test_export_inference.py b/tests/integration/cli/test_export_inference.py
index 15af715ab61..c54a49ce0b6 100644
--- a/tests/integration/cli/test_export_inference.py
+++ b/tests/integration/cli/test_export_inference.py
@@ -175,7 +175,7 @@ def run_cli_test(
     checkpoint_path: str = str(ckpt_files[-1])
     tmp_path_test = run_cli_test(recipe, checkpoint_path, Path("outputs") / "torch")
 
-    if task in ("zero_shot_visual_prompting"):
+    if task == "zero_shot_visual_prompting":
         # Check when using reference infos obtained by otx train
         idx_task = checkpoint_path.split("/").index(f"otx_train_{model_name}")
         infer_reference_info_root = [
@@ -243,7 +243,7 @@ def run_cli_test(
     tmp_path_test = run_cli_test(export_test_recipe, exported_model_path, Path("outputs") / "openvino", "cpu")
     assert (tmp_path_test / "outputs").exists()
 
-    if task in ("zero_shot_visual_prompting"):
+    if task == "zero_shot_visual_prompting":
         # Check when using reference infos obtained by otx train
         idx_task = exported_model_path.split("/").index(f"otx_test_{model_name}")
         infer_reference_info_root = [
@@ -290,7 +290,7 @@ def run_cli_test(
 
     # 6) test optimized model
     tmp_path_test = run_cli_test(export_test_recipe, optimized_model_path, Path("outputs") / "nncf_ptq", "cpu")
-    if task in ("zero_shot_visual_prompting"):
+    if task == "zero_shot_visual_prompting":
         # Check when using reference infos obtained by otx train
         idx_task = optimized_model_path.split("/").index(f"otx_test_{model_name}")
         infer_reference_info_root = [
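
Two of the one-line fixes in this series are tensor-shape fixes that are easy to miss in review. First, in PATCH 1/6, `learn` previously rebound the loop variable with `label = inputs_decoder.pop("label")` and then called `np.concatenate((self.used_indices, label))`; after the fix the popped value is discarded and the dict key `label` is wrapped in a sequence before concatenation. The reason is that a scalar label (the updated unit test mocks it as a 0-d `np.array(1, dtype=np.int64)`) cannot be concatenated directly. A minimal standalone sketch, with illustrative values:

    import numpy as np

    used_indices = np.array([], dtype=np.int64)
    label = np.array(1, dtype=np.int64)  # 0-d scalar, as mocked in the updated unit test

    # np.concatenate((used_indices, label)) raises:
    #   ValueError: zero-dimensional arrays cannot be concatenated
    used_indices = np.concatenate((used_indices, [label]))  # wrap in a sequence -> shape (1,)
    print(used_indices)  # [1]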
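
Second, the `torch.stack` -> `torch.cat` change in PATCH 2/6 (applied to both the torch and OpenVINO `_customize_outputs`): stacking N one-element `LongTensor`s along dim 0 produces an (N, 1) label tensor, whereas concatenation joins along the existing dim and yields the flat (N,) vector that the prediction entity and the metric conversion expect. A standalone comparison:

    import torch

    label, num_masks = 3, 4  # illustrative values
    per_mask = [torch.LongTensor([label]) for _ in range(num_masks)]

    stacked = torch.stack(per_mask, dim=0)  # adds a new leading dim -> torch.Size([4, 1])
    flat = torch.cat(per_mask, dim=0)       # joins the existing dim -> torch.Size([4])
    print(stacked.shape, flat.shape)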
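
PATCH 3/6 gives `load_reference_info` a three-way precedence that the torch and OpenVINO models now share. Below is a condensed sketch of that flow, not the actual implementation: `_load` stands in for the model-specific `torch.load`/`pickle.load`-and-assign branch, and the literal "reference_info" path segment stands in for `self.reference_info_dir`:

    from pathlib import Path

    def _load(path: Path) -> bool:
        # stand-in: the real code loads the file and assigns
        # reference_feats/used_indices onto the model
        print(f"reference info saved at {path} was successfully loaded.")
        return True

    def load_reference_info(
        default_root_dir: Path | str,
        infer_reference_info_root: Path,
        path_to_directly_load: Path | None = None,
    ) -> bool:
        # 1) an explicit path (e.g. the one `learn` just wrote when `infer` runs
        #    without reference features in on_test_start/on_predict_start) wins
        if path_to_directly_load is not None:
            return _load(path_to_directly_load)
        # 2) otherwise resolve infer_reference_info_root, treating a relative
        #    path as relative to default_root_dir
        root = infer_reference_info_root
        if root != root.absolute():
            root = Path(default_root_dir) / root
        # 3) load only if a previously saved file exists; otherwise report failure
        candidate = root / "reference_info" / "reference_info.pt"
        if candidate.is_file():
            return _load(candidate)
        return False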
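
Finally, PATCH 4/6 only adds a pass-through: `threshold` and `num_bg_points` become parameters of the OTX-level `infer` and are forwarded to the wrapped model's `infer`, so callers can tune them per call instead of relying on the model defaults (judging by the `_get_prompt_candidates` context visible in PATCH 3/6, that is the likely consumer). A hypothetical call using the new keywords, with defaults matching the added signature:

    # hypothetical usage; `model`, `batch`, `feats`, and `idx` are placeholders
    preds = model.infer(
        batch,
        reference_feats=feats,
        used_indices=idx,
        threshold=0.0,     # new argument, default 0.0
        num_bg_points=1,   # new argument, default 1
        is_cascade=True,
    )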