diff --git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py index 82d713ecaef..7dc4114a679 100644 --- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py +++ b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py @@ -201,6 +201,7 @@ def __init__( log.warning(f"{condition}(=False) must be set to True, changed.") kwargs[condition] = True + self.is_cascade = kwargs.pop("is_cascade", True) super().__init__(*args, **kwargs) self.prompt_getter = PromptGetter(image_size=self.image_size) @@ -362,6 +363,7 @@ def infer( reference_feats: Tensor, used_indices: set[int], ori_shapes: list[Tensor], + is_cascade: bool = False, ) -> list[list[defaultdict[int, list[Tensor]]]]: """Zero-shot inference with reference features. @@ -372,6 +374,7 @@ def infer( reference_feats (Tensor): Reference features for target prediction. used_indices (set[int]): To check which indices of reference features are validate. ori_shapes (list[Tensor]): Original image size. + is_cascade (bool): Whether to use cascade inference. Defaults to False. Returns: (list[list[defaultdict[int, list[Tensor]]]]): List of predicted masks and used points. 
@@ -421,6 +424,7 @@ def infer( point_coords=point_coords, point_labels=point_labels, ori_shape=ori_shape, + is_cascade=is_cascade, ) predicted_masks[label].append(mask * point_score[2]) used_points[label].append(point_score) @@ -663,6 +667,7 @@ def _customize_inputs(self, inputs: ZeroShotVisualPromptingBatchDataEntity) -> d "reference_feats": self.model.reference_info["reference_feats"], "used_indices": self.model.reference_info["used_indices"], "ori_shapes": [torch.tensor(info.ori_shape) for info in inputs.imgs_info], + "is_cascade": self.model.is_cascade, } def _customize_outputs( # type: ignore[override] diff --git a/src/otx/recipe/visual_prompting/sam_tiny_vit.yaml b/src/otx/recipe/visual_prompting/sam_tiny_vit.yaml index 447db3472a5..51939139d80 100644 --- a/src/otx/recipe/visual_prompting/sam_tiny_vit.yaml +++ b/src/otx/recipe/visual_prompting/sam_tiny_vit.yaml @@ -29,6 +29,44 @@ engine: callback_monitor: val/Dice +# tmp: currently getting all callbacks is required to override specific component +callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/Dice # updated + mode: max + patience: 3 # updated + check_on_train_epoch_end: false + - class_path: lightning.pytorch.callbacks.RichProgressBar + init_args: + refresh_rate: 1 + leave: false + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + dirpath: null + monitor: null + mode: max + save_top_k: 1 + save_last: true + auto_insert_metric_name: false + filename: "checkpoints/epoch_{epoch:03d}" + - class_path: otx.algo.callbacks.iteration_timer.IterationTimer + init_args: + prog_bar: true + on_step: false + on_epoch: true + - class_path: lightning.pytorch.callbacks.RichModelSummary + init_args: + max_depth: 1 + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + log_momentum: true + - class_path: otx.algo.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling + init_args: + max_interval: 5 + 
decay: -0.025 + data: ../_base_/data/torchvision_base.yaml overrides: max_epochs: 100 diff --git a/src/otx/recipe/visual_prompting/sam_vit_b.yaml b/src/otx/recipe/visual_prompting/sam_vit_b.yaml index 9fa8c8203be..69afe0ddb20 100644 --- a/src/otx/recipe/visual_prompting/sam_vit_b.yaml +++ b/src/otx/recipe/visual_prompting/sam_vit_b.yaml @@ -29,6 +29,44 @@ engine: callback_monitor: val/Dice +# tmp: currently getting all callbacks is required to override specific component +callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/Dice # updated + mode: max + patience: 3 # updated + check_on_train_epoch_end: false + - class_path: lightning.pytorch.callbacks.RichProgressBar + init_args: + refresh_rate: 1 + leave: false + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + dirpath: null + monitor: null + mode: max + save_top_k: 1 + save_last: true + auto_insert_metric_name: false + filename: "checkpoints/epoch_{epoch:03d}" + - class_path: otx.algo.callbacks.iteration_timer.IterationTimer + init_args: + prog_bar: true + on_step: false + on_epoch: true + - class_path: lightning.pytorch.callbacks.RichModelSummary + init_args: + max_depth: 1 + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + log_momentum: true + - class_path: otx.algo.callbacks.adaptive_train_scheduling.AdaptiveTrainScheduling + init_args: + max_interval: 5 + decay: -0.025 + data: ../_base_/data/torchvision_base.yaml overrides: max_epochs: 100 diff --git a/src/otx/recipe/zero_shot_visual_prompting/sam_tiny_vit.yaml b/src/otx/recipe/zero_shot_visual_prompting/sam_tiny_vit.yaml index 82254fcecfb..20b40d888a4 100644 --- a/src/otx/recipe/zero_shot_visual_prompting/sam_tiny_vit.yaml +++ b/src/otx/recipe/zero_shot_visual_prompting/sam_tiny_vit.yaml @@ -8,6 +8,7 @@ model: freeze_mask_decoder: True default_threshold_reference: 0.3 default_threshold_target: 0.65 + is_cascade: False # options 
use_stability_score: False return_single_mask: False diff --git a/src/otx/recipe/zero_shot_visual_prompting/sam_vit_b.yaml b/src/otx/recipe/zero_shot_visual_prompting/sam_vit_b.yaml index 6aa48cae6b4..01bfbbc9f17 100644 --- a/src/otx/recipe/zero_shot_visual_prompting/sam_vit_b.yaml +++ b/src/otx/recipe/zero_shot_visual_prompting/sam_vit_b.yaml @@ -8,6 +8,7 @@ model: freeze_mask_decoder: True default_threshold_reference: 0.3 default_threshold_target: 0.65 + is_cascade: False # options use_stability_score: False return_single_mask: False diff --git a/tests/regression/test_regression.py b/tests/regression/test_regression.py index 47f03ae789c..cefb34bf995 100644 --- a/tests/regression/test_regression.py +++ b/tests/regression/test_regression.py @@ -562,3 +562,51 @@ def test_regression( fxt_accelerator=fxt_accelerator, tmpdir=tmpdir, ) + + +class TestZeroShotVisualPrompting(BaseTest): + # Test case parametrization for model + MODEL_TEST_CASES = [ # noqa: RUF012 + ModelTestCase(task="zero_shot_visual_prompting", name="sam_tiny_vit"), + ModelTestCase(task="zero_shot_visual_prompting", name="sam_vit_b"), + ] + # Test case parametrization for dataset + DATASET_TEST_CASES = [ # noqa: RUF012 + DatasetTestCase( + name="coco_car_person_medium_datumaro", + data_root=Path("zero_shot_visual_prompting/coco_car_person_medium_datumaro"), + data_format="datumaro", + num_classes=2, + extra_overrides={"max_epochs": "1"}, + ), + ] + + @pytest.mark.parametrize( + "model_test_case", + MODEL_TEST_CASES, + ids=[tc.name for tc in MODEL_TEST_CASES], + ) + @pytest.mark.parametrize( + "dataset_test_case", + DATASET_TEST_CASES, + ids=[tc.name for tc in DATASET_TEST_CASES], + ) + def test_regression( + self, + model_test_case: ModelTestCase, + dataset_test_case: DatasetTestCase, + fxt_dataset_root_dir: Path, + fxt_tags: dict, + fxt_num_repeat: int, + fxt_accelerator: str, + tmpdir: pytest.TempdirFactory, + ) -> None: + self._test_regression( + model_test_case=model_test_case, 
dataset_test_case=dataset_test_case, + fxt_dataset_root_dir=fxt_dataset_root_dir, + fxt_tags=fxt_tags, + fxt_num_repeat=fxt_num_repeat, + fxt_accelerator=fxt_accelerator, + tmpdir=tmpdir, + )