diff --git a/src/otx/algorithms/visual_prompting/adapters/openvino/model_wrappers/openvino_models.py b/src/otx/algorithms/visual_prompting/adapters/openvino/model_wrappers/openvino_models.py index 1bdc1a473a3..3283026252b 100644 --- a/src/otx/algorithms/visual_prompting/adapters/openvino/model_wrappers/openvino_models.py +++ b/src/otx/algorithms/visual_prompting/adapters/openvino/model_wrappers/openvino_models.py @@ -17,14 +17,12 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple, Union -import cv2 import numpy as np from openvino.model_api.adapters.inference_adapter import InferenceAdapter from openvino.model_api.models import ImageModel, SegmentationModel from openvino.model_api.models.types import NumericalValue, StringValue from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets.pipelines import ResizeLongestSide -from otx.api.utils.segmentation_utils import create_hard_prediction_from_soft_prediction class ImageEncoder(ImageModel): @@ -64,14 +62,32 @@ class PromptGetter(ImageModel): __model__ = "prompt_getter" + def __init__(self, inference_adapter, configuration=None, preload=False): + super().__init__(inference_adapter, configuration, preload) + @classmethod def parameters(cls) -> Dict[str, Any]: # noqa: D102 parameters = super().parameters() parameters.update({"image_size": NumericalValue(value_type=int, default_value=1024, min=0, max=2048)}) parameters.update({"sim_threshold": NumericalValue(value_type=float, default_value=0.5, min=0, max=1)}) parameters.update({"num_bg_points": NumericalValue(value_type=int, default_value=1, min=0, max=1024)}) + parameters.update( + {"default_threshold_reference": NumericalValue(value_type=float, default_value=0.3, min=-1.0, max=1.0)} + ) return parameters + def _get_inputs(self): + """Defines the model inputs for images and additional info.""" + image_blob_names, image_info_blob_names = [], [] + for name, metadata in self.inputs.items(): + if len(metadata.shape) == 4: + image_blob_names.append(name) + else: + image_info_blob_names.append(name) + if not image_blob_names: + self.raise_error("Failed to identify the input for the image: no 4D input layer found") + return image_blob_names, image_info_blob_names + class Decoder(SegmentationModel): """Decoder class for visual prompting of openvino model wrapper.""" @@ -86,6 +102,9 @@ def __init__( ): super().__init__(model_adapter, configuration, preload) + self.mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32) + self.has_mask_input = np.zeros((1, 1), dtype=np.float32) + @classmethod def parameters(cls): # noqa: D102 parameters = super().parameters() @@ -94,27 +113,30 @@ def parameters(cls): # noqa: D102 return parameters def _get_outputs(self): - return "low_res_masks" + return "upscaled_masks" def preprocess(self, inputs: Dict[str, Any], meta: Dict[str, Any]) -> List[Dict[str, Any]]: """Preprocess prompts.""" processed_prompts = [] - # TODO (sungchul): process points - for bbox, label in zip(inputs["bboxes"], inputs["labels"]): - # TODO (sungchul): add condition to check whether using bbox or point - point_coords = self._apply_coords(bbox.reshape(-1, 2, 2), inputs["original_size"]) - point_labels = np.array([2, 3], dtype=np.float32).reshape((-1, 2)) - processed_prompts.append( - { - "point_coords": point_coords, - "point_labels": point_labels, - # TODO (sungchul): how to generate mask_input and has_mask_input - "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32), - "has_mask_input": np.zeros((1, 1), dtype=np.float32), - "orig_size": 
np.array(inputs["original_size"], dtype=np.float32).reshape((-1, 2)), - "label": label, - } - ) + for prompt_name in ["bboxes", "points"]: + for prompt, label in zip(inputs.get(prompt_name), inputs["labels"].get(prompt_name, [])): + if prompt_name == "bboxes": + point_coords = self._apply_coords(prompt.reshape(-1, 2, 2), inputs["original_size"]) + point_labels = np.array([2, 3], dtype=np.float32).reshape(-1, 2) + else: + point_coords = self._apply_coords(prompt.reshape(-1, 1, 2), inputs["original_size"]) + point_labels = np.array([1], dtype=np.float32).reshape(-1, 1) + + processed_prompts.append( + { + "point_coords": point_coords, + "point_labels": point_labels, + "mask_input": self.mask_input, + "has_mask_input": self.has_mask_input, + "orig_size": np.asarray(inputs["original_size"], dtype=np.int64).reshape(-1, 2), + "label": label, + } + ) return processed_prompts def _apply_coords(self, coords: np.ndarray, original_size: Union[List[int], Tuple[int, int]]) -> np.ndarray: @@ -152,64 +174,13 @@ def postprocess(self, outputs: Dict[str, np.ndarray], meta: Dict[str, Any]) -> T Returns: hard_prediction (np.ndarray): The hard prediction. - soft_prediction (np.ndarray): Resized, cropped, and normalized soft prediction. + soft_prediction (np.ndarray): The soft prediction. """ + probability = max(min(float(outputs["scores"]), 1.0), 0.0) + hard_prediction = outputs[self.output_blob_name].squeeze() > self.mask_threshold + soft_prediction = hard_prediction * probability - def sigmoid(x): - return np.tanh(x * 0.5) * 0.5 + 0.5 # to avoid overflow - - soft_prediction = outputs[self.output_blob_name].squeeze() - soft_prediction = self.resize_and_crop(soft_prediction, meta["original_size"][0]) - soft_prediction = sigmoid(soft_prediction) meta["soft_prediction"] = soft_prediction - - hard_prediction = create_hard_prediction_from_soft_prediction( - soft_prediction=soft_prediction, - soft_threshold=self.soft_threshold, - blur_strength=self.blur_strength, - ) - - probability = max(min(float(outputs["iou_predictions"]), 1.0), 0.0) meta["label"].probability = probability return hard_prediction, soft_prediction - - def resize_and_crop(self, soft_prediction: np.ndarray, original_size: np.ndarray) -> np.ndarray: - """Resize and crop soft prediction. - - Args: - soft_prediction (np.ndarray): Predicted soft prediction with HxW shape. - original_size (np.ndarray): The original image size. - - Returns: - final_soft_prediction (np.ndarray): Resized and cropped soft prediction for the original image. - """ - resized_soft_prediction = cv2.resize( - soft_prediction, (self.image_size, self.image_size), 0, 0, interpolation=cv2.INTER_LINEAR - ) - - prepadded_size = self.get_padded_size(original_size, self.image_size).astype(np.int64) - resized_cropped_soft_prediction = resized_soft_prediction[: prepadded_size[0], : prepadded_size[1], ...] - - original_size = original_size.astype(np.int64) - h, w = original_size - final_soft_prediction = cv2.resize( - resized_cropped_soft_prediction, (w, h), 0, 0, interpolation=cv2.INTER_LINEAR - ) - return final_soft_prediction - - def get_padded_size(self, original_size: np.ndarray, longest_side: int) -> np.ndarray: - """Get padded size from original size and longest side of the image. - - Args: - original_size (np.ndarray): The original image size with shape Bx2. - longest_side (int): The size of the longest side. - - Returns: - transformed_size (np.ndarray): The transformed image size with shape Bx2. 
- """ - original_size = original_size.astype(np.float32) - scale = longest_side / np.max(original_size) - transformed_size = scale * original_size - transformed_size = np.floor(transformed_size + 0.5).astype(np.int64) - return transformed_size diff --git a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/callbacks/inference.py b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/callbacks/inference.py index 1dc39b7cc3f..df751eeaba5 100644 --- a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/callbacks/inference.py +++ b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/callbacks/inference.py @@ -57,7 +57,7 @@ def on_predict_epoch_end(self, _trainer: Trainer, _pl_module: LightningModule, o for output in outputs[0]: pred_masks.append(output["masks"][0]) iou_predictions.append(output["iou_predictions"][0]) - pred_labels.append(output["labels"][0]) + pred_labels.append(output["labels"][0].get("bboxes", []) + output["labels"][0].get("points", [])) for dataset_item, pred_mask, iou_prediction, labels in zip( self.otx_dataset, pred_masks, iou_predictions, pred_labels diff --git a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/dataset.py b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/dataset.py index 476a2c09d69..dcf336776b1 100644 --- a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/dataset.py +++ b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/dataset.py @@ -14,7 +14,8 @@ # See the License for the specific language governing permissions # and limitations under the License. -from typing import Any, Dict, List, Optional, Union +from collections import defaultdict +from typing import Any, DefaultDict, Dict, List, Optional, Union import cv2 import numpy as np @@ -130,13 +131,6 @@ def generate_bbox_from_mask(gt_mask: np.ndarray, width: int, height: int) -> Lis return generate_bbox(x_min, y_min, x_max, y_max, width, height) -def generate_point_from_mask(gt_mask: np.ndarray) -> np.ndarray: - """Randomly generate point from given mask.""" - candidates = np.where(gt_mask == 1) - index = np.random.permutation(len(candidates))[0] - return candidates[index] - - class OTXVisualPromptingDataset(Dataset): """Visual Prompting Dataset Adaptor. @@ -149,14 +143,33 @@ class OTXVisualPromptingDataset(Dataset): """ def __init__( - self, dataset: DatasetEntity, image_size: int, mean: List[float], std: List[float], offset_bbox: int = 0 + self, + mode: Subset, + dataset: DatasetEntity, + image_size: int, + mean: List[float], + std: List[float], + offset_bbox: int = 0, + use_point: bool = False, + use_bbox: bool = False, ) -> None: - + self.mode = mode self.dataset = dataset self.transform = get_transform(image_size, mean, std) self.offset_bbox = offset_bbox self.labels = dataset.get_labels() + if not use_bbox and not use_point: + # if both are False, use bbox as default + use_bbox = True + self.prob = 1.0 # if using only bbox prompt + if use_bbox and use_point: + # if using both prompts, divide prob into both + self.prob = 0.5 + if not use_bbox and use_point: + # if using only point prompt + self.prob = 0.0 + def __len__(self) -> int: """Get size of the dataset. 
@@ -166,21 +179,28 @@ def __len__(self) -> int: return len(self.dataset) @staticmethod - def get_prompts(dataset_item: DatasetItemEntity, dataset_labels: List[LabelEntity]) -> Dict[str, Any]: + def get_prompts( + dataset_item: DatasetItemEntity, + dataset_labels: List[LabelEntity], + prob: float = 1.0, + mode: Subset = Subset.TESTING, + ) -> Dict[str, Any]: """Get propmts from dataset_item. Args: dataset_item (DatasetItemEntity): Dataset item entity. dataset_labels (List[LabelEntity]): Label information. + prob (float): Probability of which prompts will be generated. + mode (Subset): To check which mode is used between training, validation, and testing. Returns: Dict[str, Any]: Processed prompts with ground truths. """ width, height = dataset_item.width, dataset_item.height - bboxes: List[List[int]] = [] - points: List = [] # TBD + bboxes: List[np.ndarray] = [] + points: List[np.ndarray] = [] gt_masks: List[np.ndarray] = [] - labels: List[ScoredLabel] = [] + labels: DefaultDict[str, List[ScoredLabel]] = defaultdict(list) for annotation in dataset_item.get_annotations(labels=dataset_labels, include_empty=False, preserve_id=True): if isinstance(annotation.shape, Image): # use mask as-is @@ -192,25 +212,36 @@ def get_prompts(dataset_item: DatasetItemEntity, dataset_labels: List[LabelEntit continue if gt_mask.sum() == 0: - # pass no gt + # pass no gt or very small region continue - gt_masks.append(gt_mask) - - # generate bbox based on gt_mask - bbox = generate_bbox_from_mask(gt_mask, width, height) - bboxes.append(bbox) - - # TODO (sungchul): generate random points from gt_mask - # add labels - labels.extend(annotation.get_labels(include_empty=False)) + gt_masks.append(gt_mask) - bboxes = np.array(bboxes) + mask_points = np.nonzero(gt_mask) + if np.random.rand() < prob: + # generate bbox based on gt_mask + bbox = generate_bbox_from_mask(gt_mask, width, height) + bboxes.append(bbox) + labels["bboxes"].extend(annotation.get_labels(include_empty=False)) + else: + # generate point based on gt_mask + if mode == Subset.TRAINING: + # get random point from the mask + idx_chosen = np.random.permutation(len(mask_points[0]))[0] # noqa: NPY002 + point = np.array([mask_points[1][idx_chosen], mask_points[0][idx_chosen]]) + else: + # get averaged point + point = np.array([mask_points[1].mean(), mask_points[0].mean()]) + points.append(point) + labels["points"].extend(annotation.get_labels(include_empty=False)) + + bboxes = np.array(bboxes, dtype=np.float32) if len(bboxes) > 0 else np.zeros((0, 4), dtype=np.float32) + points = np.array(points, dtype=np.float32) if len(points) > 0 else np.zeros((0, 2), dtype=np.float32) return dict( original_size=np.array((height, width), dtype=np.int64), gt_masks=gt_masks, bboxes=bboxes, - points=points, # TODO (sungchul): update point information + points=points, labels=labels, ) @@ -226,7 +257,7 @@ def __getitem__(self, index: int) -> Dict[str, Union[int, List, Tensor]]: dataset_item = self.dataset[index] item: Dict[str, Union[int, Tensor]] = {"index": index, "images": dataset_item.numpy} - prompts = self.get_prompts(dataset_item, self.labels) + prompts = self.get_prompts(dataset_item, self.labels, self.prob, self.mode) if len(prompts["gt_masks"]) == 0: return { "images": [], @@ -238,7 +269,6 @@ def __getitem__(self, index: int) -> Dict[str, Union[int, List, Tensor]]: "labels": [], } - prompts["bboxes"] = np.array(prompts["bboxes"]) item.update({**prompts, "path": dataset_item.media.path}) item = self.transform(item) return item @@ -247,20 +277,6 @@ def __getitem__(self, 
index: int) -> Dict[str, Union[int, List, Tensor]]: class OTXZeroShotVisualPromptingDataset(OTXVisualPromptingDataset): """Visual Prompting for Zero-shot learning Dataset Adaptor.""" - def __init__( - self, - dataset: DatasetEntity, - image_size: int, - mean: List[float], - std: List[float], - generate_point: bool = False, - generate_bbox: bool = False, - **kwargs, - ) -> None: - super().__init__(dataset, image_size, mean, std, offset_bbox=0) - self.generate_point = generate_point - self.generate_bbox = generate_bbox - def __getitem__(self, index: int) -> Dict[str, Union[int, List, Tensor]]: """Get dataset item. @@ -273,9 +289,9 @@ def __getitem__(self, index: int) -> Dict[str, Union[int, List, Tensor]]: dataset_item = self.dataset[index] item: Dict[str, Union[int, Tensor]] = {"index": index, "images": dataset_item.numpy} - prompts = self.get_prompts(dataset_item, self.labels) # , self.generate_point, self.generate_bbox) + prompts = self.get_prompts(dataset_item, self.labels, self.prob) item.update({**prompts, "path": dataset_item.media.path}) - item = self.transform(item) + return item @@ -314,12 +330,12 @@ def __init__( ) self.config["train_batch_size"] = 1 - self.kwargs.update( - { - "generate_point": self.config.get("generate_point", False), - "generate_bbox": self.config.get("generate_bbox", False), - } - ) + self.kwargs.update( + { + "use_point": self.config.get("use_point", False), + "use_bbox": self.config.get("use_bbox", False), + } + ) self.train_otx_dataset: DatasetEntity self.val_otx_dataset: DatasetEntity @@ -340,6 +356,7 @@ def setup(self, stage: Optional[str] = None) -> None: std = self.config.normalize.std if stage == "fit" or stage is None: self.train_dataset = self.DATASETS[self.train_type]( + mode=Subset.TRAINING, dataset=self.dataset.get_subset(Subset.TRAINING), image_size=image_size, mean=mean, @@ -351,17 +368,32 @@ def setup(self, stage: Optional[str] = None) -> None: # self.val_dataset = None if self.train_type == TrainType.Incremental: self.val_dataset = self.DATASETS[self.train_type]( - dataset=self.dataset.get_subset(Subset.VALIDATION), image_size=image_size, mean=mean, std=std + mode=Subset.VALIDATION, + dataset=self.dataset.get_subset(Subset.VALIDATION), + image_size=image_size, + mean=mean, + std=std, + **self.kwargs, ) if stage == "test": self.test_dataset = self.DATASETS[self.train_type]( - dataset=self.dataset.get_subset(Subset.TESTING), image_size=image_size, mean=mean, std=std + mode=Subset.TESTING, + dataset=self.dataset.get_subset(Subset.TESTING), + image_size=image_size, + mean=mean, + std=std, + **self.kwargs, ) if stage == "predict": self.predict_dataset = self.DATASETS[self.train_type]( - dataset=self.dataset, image_size=image_size, mean=mean, std=std, **self.kwargs + mode=Subset.TESTING, + dataset=self.dataset.get_subset(Subset.TESTING), + image_size=image_size, + mean=mean, + std=std, + **self.kwargs, ) def summary(self): @@ -375,58 +407,66 @@ def summary(self): num_items, ) - def train_dataloader(self) -> Union[DataLoader, List[DataLoader], Dict[str, DataLoader]]: + def train_dataloader(self) -> DataLoader: """Train Dataloader. Returns: - Union[DataLoader, List[DataLoader], Dict[str, DataLoader]]: Train dataloader. + DataLoader: Train dataloader. 
""" return DataLoader( self.train_dataset, shuffle=True, batch_size=self.config.train_batch_size, num_workers=self.config.num_workers, - collate_fn=collate_fn, + collate_fn=collate_fn + if self.train_type != TrainType.Zeroshot + else lambda x: x, # type: ignore[return-value] ) - def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + def val_dataloader(self) -> DataLoader: """Validation Dataloader. Returns: - Union[DataLoader, List[DataLoader]]: Validation Dataloader. + DataLoader: Validation Dataloader. """ return DataLoader( self.val_dataset, shuffle=False, batch_size=self.config.val_batch_size, num_workers=self.config.num_workers, - collate_fn=collate_fn, + collate_fn=collate_fn + if self.train_type != TrainType.Zeroshot + else lambda x: x, # type: ignore[return-value] ) - def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + def test_dataloader(self) -> DataLoader: """Test Dataloader. Returns: - Union[DataLoader, List[DataLoader]]: Test Dataloader. + DataLoader: Test Dataloader. """ return DataLoader( self.test_dataset, shuffle=False, batch_size=self.config.test_batch_size, num_workers=self.config.num_workers, - collate_fn=collate_fn, + collate_fn=collate_fn + if self.train_type != TrainType.Zeroshot + else lambda x: x, # type: ignore[return-value] ) - def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + def predict_dataloader(self) -> DataLoader: """Predict Dataloader. Returns: - Union[DataLoader, List[DataLoader]]: Predict Dataloader. + DataLoader: Predict Dataloader. """ return DataLoader( self.predict_dataset, shuffle=False, batch_size=1, num_workers=self.config.num_workers, - collate_fn=collate_fn, + collate_fn=collate_fn + if self.train_type != TrainType.Zeroshot + else lambda x: x, # type: ignore[return-value] ) diff --git a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/sam_transforms.py b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/sam_transforms.py index 06d04ea817d..63a58b9229e 100644 --- a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/sam_transforms.py +++ b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/sam_transforms.py @@ -36,9 +36,11 @@ def __call__(self, item: Dict[str, Union[List, Tensor]]) -> Dict[str, Union[List item["images"] = torch.as_tensor( self.apply_image(item["images"], self.target_length).transpose((2, 0, 1)), dtype=torch.get_default_dtype() ) - item["gt_masks"] = [torch.as_tensor(gt_mask) for gt_mask in item["gt_masks"]] - item["bboxes"] = self.apply_boxes(item["bboxes"], item["original_size"], self.target_length) - if item["points"]: + if "gt_masks" in item: + item["gt_masks"] = [torch.as_tensor(gt_mask) for gt_mask in item["gt_masks"]] + if "bboxes" in item: + item["bboxes"] = self.apply_boxes(item["bboxes"], item["original_size"], self.target_length) + if "points" in item: item["points"] = self.apply_coords(item["points"], item["original_size"], self.target_length) return item @@ -78,9 +80,9 @@ def apply_coords( old_h, old_w = original_size new_h, new_w = cls.get_preprocess_shape(original_size[0], original_size[1], target_length) if isinstance(coords, np.ndarray): - coords = coords.astype(float) + coords = coords.astype(np.float32) else: - coords = coords.to(torch.float) + coords = coords.to(torch.float32) coords[..., 0] = coords[..., 0] * (new_w / old_w) coords[..., 1] = coords[..., 1] * (new_h / old_h) return coords diff --git 
a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/transforms.py b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/transforms.py index f53fb4b3457..dd1abddf740 100644 --- a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/transforms.py +++ b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/transforms.py @@ -34,16 +34,15 @@ def _convert_empty_to_none(x: str, dtype: torch.dtype = torch.float32) -> List: List: List of batch data. """ func = torch.stack if x == "gt_masks" else torch.tensor - items = [func(item[x]).to(dtype) for item in batch if item[x] is not None] - return None if len(items) == 0 else items + items = [func(item[x]).to(dtype) if len(item[x]) > 0 else None for item in batch] + return items index = [item["index"] for item in batch] images = torch.stack([item["images"] for item in batch]) bboxes = _convert_empty_to_none("bboxes") - points = None # TBD + points = _convert_empty_to_none("points") gt_masks = _convert_empty_to_none("gt_masks", torch.int32) original_size = _convert_empty_to_none("original_size") - padding = [item["padding"] for item in batch] path = [item["path"] for item in batch] labels = [item["labels"] for item in batch] if gt_masks: @@ -56,7 +55,6 @@ def _convert_empty_to_none(x: str, dtype: torch.dtype = torch.float32) -> List: "original_size": original_size, "path": path, "labels": labels, - "padding": padding, } return { "index": -1, @@ -67,7 +65,6 @@ def _convert_empty_to_none(x: str, dtype: torch.dtype = torch.float32) -> List: "original_size": [], "path": [], "labels": [], - "padding": [], } @@ -89,7 +86,6 @@ def __call__(self, item: Dict[str, Union[List[Any], Tensor]]) -> Dict[str, Union pad_h = max_dim - h padding = (0, 0, pad_w, pad_h) - item["padding"] = padding item["images"] = pad(item["images"], padding, fill=0, padding_mode="constant") return item diff --git a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/encoders/sam_image_encoder.py b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/encoders/sam_image_encoder.py index f823593d592..6944754c660 100644 --- a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/encoders/sam_image_encoder.py +++ b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/encoders/sam_image_encoder.py @@ -5,7 +5,7 @@ # from omegaconf import DictConfig -from torch import Tensor, nn +from torch import nn from otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.backbones import ( build_tiny_vit, @@ -20,25 +20,13 @@ class SAMImageEncoder(nn.Module): config (DictConfig): Config for image encoder. """ - def __init__(self, config: DictConfig): - super().__init__() + def __new__(cls, config: DictConfig): + """Initialize SAM image encoder to the target backbone.""" if "tiny_vit" == config.backbone: - self.backbone = build_tiny_vit(config.image_size) + return build_tiny_vit(config.image_size) elif "vit" in config.backbone: - self.backbone = build_vit(config.backbone, config.image_size) + return build_vit(config.backbone, config.image_size) else: raise NotImplementedError( (f"{config.backbone} for image encoder of SAM is not implemented yet. " f"Use vit_b, l, or h.") ) - - def forward(self, images: Tensor) -> Tensor: - """Forward function of image encoder. - - Args: - images (Tensor): Input tensor. - - Returns: - image_embeddings (Tensor): Output tensor. 
- """ - image_embeddings = self.backbone(images) - return image_embeddings diff --git a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py index 9581d21ab41..9327f7f2b84 100644 --- a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py +++ b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py @@ -6,11 +6,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# -import re + from collections import OrderedDict -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple import torch from omegaconf import DictConfig @@ -133,37 +132,19 @@ def set_metrics(self) -> None: ) ) - def load_checkpoint( - self, - state_dict: Optional[OrderedDict] = None, - revise_keys: List = [(r"^image_encoder.", r"image_encoder.backbone.")], - ) -> None: + def load_checkpoint(self, state_dict: Optional[OrderedDict] = None) -> None: """Load checkpoint for SAM. Args: state_dict (Optional[OrderedDict], optional): State dict of SAM. Defaults to None. - revise_keys (List, optional): List of tuples of regex patterns to revise keys of state_dict. - Defaults to [(r'^image_encoder.', r'image_encoder.backbone.')]. """ - - def replace_state_dict_keys(state_dict, revise_keys): - for p, r in revise_keys: - state_dict = OrderedDict( - { - re.sub(p, r, k) if re.search(p, k) and not re.search(r, k) else k: v - for k, v in state_dict.items() - } - ) - return state_dict - if state_dict: # state_dict from args.load_from - state_dict = replace_state_dict_keys(state_dict, revise_keys) self.load_state_dict(state_dict) elif self.config.model.checkpoint: if str(self.config.model.checkpoint).endswith(".ckpt"): # load lightning checkpoint - self.load_from_checkpoint(self.config.model.checkpoint) + self.load_from_checkpoint(self.config.model.checkpoint, strict=False) else: if str(self.config.model.checkpoint).startswith("http"): # get checkpoint from url @@ -172,8 +153,12 @@ def replace_state_dict_keys(state_dict, revise_keys): # load checkpoint from local with open(self.config.model.checkpoint, "rb") as f: state_dict = torch.load(f) - state_dict = replace_state_dict_keys(state_dict, revise_keys) + self.load_state_dict(state_dict, strict=False) + else: + # use default checkpoint + state_dict = torch.hub.load_state_dict_from_url(CKPT_PATHS[self.config.model.backbone]) + self.load_state_dict(state_dict, strict=False) ########################################################## # forward for inference (export/deploy/optimize) # @@ -186,7 +171,7 @@ def forward( point_labels: Tensor, mask_input: Tensor, has_mask_input: Tensor, - # orig_size: Tensor, + orig_size: Tensor, ): """Forward method for SAM inference (export/deploy). 
@@ -228,18 +213,16 @@ def forward( if self.config.model.return_single_mask: masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) - return scores, masks - # TODO (sungchul): apply inner postprocessing - # upscaled_masks = self.mask_postprocessing(masks, orig_size[0]) + upscaled_masks = self.postprocess_masks(masks, self.config.model.image_size, orig_size[0]) - # if self.config.model.return_extra_metrics: - # stability_scores = self.calculate_stability_score( - # upscaled_masks, self.config.model.mask_threshold, self.config.model.stability_score_offset - # ) - # areas = (upscaled_masks > self.config.model.mask_threshold).sum(-1).sum(-1) - # return upscaled_masks, scores, stability_scores, areas, masks + if self.config.model.return_extra_metrics: + stability_scores = self.calculate_stability_score( + upscaled_masks, self.config.model.mask_threshold, self.config.model.stability_score_offset + ) + areas = (upscaled_masks > self.config.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks - # return upscaled_masks, scores, masks + return upscaled_masks, scores, masks def _embed_points(self, point_coords: Tensor, point_labels: Tensor) -> Tensor: """Embed sparse input prompts. @@ -334,9 +317,9 @@ def select_masks(self, masks: Tensor, iou_preds: Tensor, num_points: int) -> Tup return masks, iou_preds - @staticmethod - def mask_postprocessing(masks: Tensor, input_size: int, orig_size: Tensor) -> Tensor: - """Postprocesses the predicted masks. + @classmethod + def postprocess_masks(cls, masks: Tensor, input_size: int, orig_size: Tensor) -> Tensor: + """Postprocess the predicted masks. Args: masks (Tensor): A batch of predicted masks with shape Bx1xHxW. @@ -347,22 +330,20 @@ def mask_postprocessing(masks: Tensor, input_size: int, orig_size: Tensor) -> Te Returns: masks (Tensor): The postprocessed masks with shape Bx1xHxW. 
""" - - def resize_longest_image_size(input_image_size: Tensor, longest_side: int) -> Tensor: - scale = longest_side / torch.max(input_image_size) - transformed_size = scale * input_image_size - transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) - return transformed_size - masks = F.interpolate(masks, size=(input_size, input_size), mode="bilinear", align_corners=False) - prepadded_size = resize_longest_image_size(orig_size, input_size) - masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore + prepadded_size = cls.get_prepadded_size(cls, orig_size, input_size) # type: ignore[arg-type] + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] orig_size = orig_size.to(torch.int64) h, w = orig_size[0], orig_size[1] - masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) - return masks + return F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + + def get_prepadded_size(self, input_image_size: Tensor, longest_side: int) -> Tensor: + """Get pre-padded size.""" + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + return torch.floor(transformed_size + 0.5).to(torch.int64) ###################################################### # forward for training/validation/prediction # @@ -395,23 +376,32 @@ def forward_train( image_embeddings = self.image_encoder(images) pred_masks = [] ious = [] - for embedding, bbox in zip(image_embeddings, bboxes): - sparse_embeddings, dense_embeddings = self.prompt_encoder( - points=points, - boxes=bbox, - masks=masks, - ) + for idx, embedding in enumerate(image_embeddings): + low_res_masks, iou_predictions = [], [] + for idx_prompt, prompt in enumerate([bboxes[idx], points[idx]]): + if prompt is None: + continue + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=(prompt.unsqueeze(1), torch.ones(len(prompt), 1, device=prompt.device)) + if idx_prompt == 1 + else None, + boxes=prompt if idx_prompt == 0 else None, + masks=None, + ) - low_res_masks, iou_predictions = self.mask_decoder( - image_embeddings=embedding.unsqueeze(0), - image_pe=self.prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, - multimask_output=False, # when given multiple prompts. if there is single prompt True would be better. - ) + _low_res_masks, _iou_predictions = self.mask_decoder( + image_embeddings=embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=False, # when given multiple prompts. if there is single prompt True would be better. 
# noqa: E501 + ) + low_res_masks.append(_low_res_masks) + iou_predictions.append(_iou_predictions) - pred_masks.append(low_res_masks) - ious.append(iou_predictions) + pred_masks.append(torch.cat(low_res_masks, dim=0)) + ious.append(torch.cat(iou_predictions, dim=0)) return pred_masks, ious @@ -441,10 +431,8 @@ def training_step(self, batch, batch_idx) -> Tensor: num_masks = sum(len(pred_mask) for pred_mask in pred_masks) for i, (pred_mask, gt_mask, iou_prediction) in enumerate(zip(pred_masks, gt_masks, iou_predictions)): - pred_mask = self.postprocess_masks( - pred_mask, images.shape[2:], batch["padding"][i], batch["original_size"][i] - ) - pred_mask = pred_mask.sigmoid() + pred_mask = self.postprocess_masks(pred_mask, self.config.model.image_size, batch["original_size"][i]) + pred_mask = pred_mask.sigmoid().squeeze(1) self.train_metrics["train_IoU"].update(pred_mask, gt_mask) self.train_metrics["train_F1"].update(pred_mask, gt_mask) self.train_metrics["train_Dice"].update(pred_mask, gt_mask) @@ -499,10 +487,8 @@ def validation_step(self, batch, batch_idx) -> MetricCollection: pred_masks, _ = self.forward_train(images, bboxes, points) for i, (pred_mask, gt_mask) in enumerate(zip(pred_masks, gt_masks)): - pred_mask = self.postprocess_masks( - pred_mask, images.shape[2:], batch["padding"][i], batch["original_size"][i] - ) - pred_mask = pred_mask.sigmoid() + pred_mask = self.postprocess_masks(pred_mask, self.config.model.image_size, batch["original_size"][i]) + pred_mask = pred_mask.sigmoid().squeeze(1) for k, v in self.val_metrics.items(): v.update(pred_mask, gt_mask) @@ -532,41 +518,15 @@ def predict_step(self, batch, batch_idx) -> Dict[str, Tensor]: masks: List[Tensor] = [] for i, pred_mask in enumerate(pred_masks): - mask = self.postprocess_masks(pred_mask, images.shape[2:], batch["padding"][i], batch["original_size"][i]) + mask = self.postprocess_masks(pred_mask, self.config.model.image_size, batch["original_size"][i]) if not self.config.model.return_logits: mask = (mask > self.config.model.mask_threshold).to(mask.dtype) else: mask = mask.sigmoid() - masks.append(mask) + masks.append(mask.squeeze(1)) return dict(masks=masks, iou_predictions=iou_predictions, path=batch["path"], labels=batch["labels"]) - @staticmethod - def postprocess_masks( - masks: Tensor, - input_size: Tuple[int, int], - padding: Union[Tuple[int, ...], Tensor], - original_size: Union[Tuple[int, int], Tensor], - ) -> Tensor: - """Remove padding and upscale masks to the original image size. - - Args: - masks (Tensor): Predicted masks from the mask_decoder with (N, 1, H/downsized_ratio, W/downsized_ratio). - input_size (tuple(int, int)): The size of the image input to the model, in (H, W) format. - Used to remove padding. - padding (tuple(int, int, int, int), Tensor): The padding applied to the image before input to the model, - in (left, top, right, bottom) format. - original_size (tuple(int, int), Tensor): The original size of the image before resizing - for input to the model, in (H, W) format. - - Returns: - (Tensor): Postprocessed masks in NxHxW format, where (H, W) is given by original_size. - """ - masks = F.interpolate(masks, input_size, mode="bilinear", align_corners=False) - masks = masks[..., : input_size[0] - padding[3], : input_size[1] - padding[2]] - masks = F.interpolate(masks, [int(o) for o in original_size], mode="bilinear", align_corners=False) - return masks.squeeze(1) - def configure_optimizers(self) -> optim: """Configure the optimizer for SAM. 
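The mask postprocessing consolidated into `postprocess_masks` / `get_prepadded_size` above (upscale to the square model input, crop away the resize-longest-side padding, then rescale to the original resolution) reduces to the following standalone sketch; the 1024 input size and the toy tensors are assumptions for illustration:

import torch
import torch.nn.functional as F


def get_prepadded_size(original_size: torch.Tensor, longest_side: int) -> torch.Tensor:
    """Scale (H, W) so that the longest side equals `longest_side`."""
    scale = longest_side / torch.max(original_size.to(torch.float32))
    return torch.floor(scale * original_size + 0.5).to(torch.int64)


def postprocess_masks(masks: torch.Tensor, input_size: int, original_size: torch.Tensor) -> torch.Tensor:
    """Upscale Bx1xhxw decoder masks to the original image size."""
    masks = F.interpolate(masks, size=(input_size, input_size), mode="bilinear", align_corners=False)
    prepadded = get_prepadded_size(original_size, input_size)
    masks = masks[..., : prepadded[0], : prepadded[1]]  # drop the zero-padding added for the square input
    h, w = map(int, original_size)
    return F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)


low_res = torch.randn(1, 1, 256, 256)  # e.g. a SAM decoder output
print(postprocess_masks(low_res, 1024, torch.tensor([480, 640])).shape)  # torch.Size([1, 1, 480, 640])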
diff --git a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/zero_shot_segment_anything.py b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/zero_shot_segment_anything.py index 14c8e5dd6f2..5948f2e42bb 100644 --- a/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/zero_shot_segment_anything.py +++ b/src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/zero_shot_segment_anything.py @@ -3,19 +3,24 @@ # Copyright (C) 2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import json +import os +import pickle +import time from collections import OrderedDict, defaultdict from copy import deepcopy from itertools import product from typing import Any, DefaultDict, Dict, List, Optional, Tuple, Union +import cv2 +import numpy as np import torch from omegaconf import DictConfig -from torch import nn +from torch import Tensor, nn +from torch.nn import Parameter, ParameterDict from torch.nn import functional as F -from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets.pipelines import ( - ResizeLongestSide, -) +from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets.dataset import get_transform from otx.api.entities.scored_label import ScoredLabel from otx.utils.logger import get_logger @@ -30,114 +35,71 @@ class PromptGetter(nn.Module): default_threshold_reference = 0.3 default_threshold_target = 0.65 - def __init__( - self, - image_size: int, - reference_feats: Optional[torch.Tensor] = None, - reference_prompts: Optional[torch.Tensor] = None, - downsizing: int = 64, - ) -> None: + def __init__(self, image_size: int, downsizing: int = 64) -> None: super().__init__() self.image_size = image_size self.downsizing = downsizing - self.initialize(reference_feats, reference_prompts) - - self.zero_tensor = torch.tensor(0) - - def initialize( - self, reference_feats: Optional[torch.Tensor] = None, reference_prompts: Optional[torch.Tensor] = None - ) -> None: - """Initialize reference features and prompts.""" - self.reference_feats = reference_feats - self.reference_prompts = reference_prompts def set_default_thresholds(self, default_threshold_reference: float, default_threshold_target: float) -> None: """Set default thresholds.""" self.default_threshold_reference = default_threshold_reference self.default_threshold_target = default_threshold_target - def set_reference(self, label: ScoredLabel, reference_feats: torch.Tensor, reference_prompts: torch.Tensor) -> None: - """Set reference features and prompts.""" - if self.reference_feats is None: - self.reference_feats = torch.zeros_like(reference_feats).unsqueeze(0) - if self.reference_prompts is None: - self.reference_prompts = torch.zeros_like(reference_prompts).unsqueeze(0) - - for idx in range(int(label.id_) + 1): - if idx == int(label.id_): - while self.reference_feats.shape[0] - 1 < idx: - self.reference_feats = torch.cat( - (self.reference_feats, torch.zeros_like(reference_feats).unsqueeze(0)), dim=0 - ) - self.reference_prompts = torch.cat( - (self.reference_prompts, torch.zeros_like(reference_prompts).unsqueeze(0)), dim=0 - ) - self.reference_feats[idx] = reference_feats - self.reference_prompts[idx] = reference_prompts - - def forward( + def get_prompt_candidates( self, - image_embeddings: torch.Tensor, - original_size: torch.Tensor, - threshold: torch.Tensor = torch.tensor([[0.0]], dtype=torch.float32), - num_bg_points: torch.Tensor = torch.tensor([[1]], 
dtype=torch.int64), - ) -> Tuple[torch.Tensor, torch.Tensor]: + image_embeddings: Tensor, + reference_feats: Tensor, + used_indices: Tensor, + original_size: Tensor, + threshold: Tensor = torch.tensor([[0.0]], dtype=torch.float32), + num_bg_points: Tensor = torch.tensor([[1]], dtype=torch.int64), + device: Union[torch.device, str] = torch.device("cpu"), + ) -> Tuple[Dict[int, Tensor], Dict[int, Tensor]]: """Get prompt candidates.""" - total_points_scores: torch.Tensor - total_bg_coords: torch.Tensor - - device = image_embeddings.device threshold = threshold.to(device) - for label in torch.arange(self.reference_feats.shape[0]): - points_scores, bg_coords = self.get_prompt_candidates( + + total_points_scores: Dict[int, Tensor] = {} + total_bg_coords: Dict[int, Tensor] = {} + for label in map(int, used_indices[0]): + points_scores, bg_coords = self( image_embeddings=image_embeddings, - label=label, + reference_feat=reference_feats[label], original_size=original_size, threshold=threshold, num_bg_points=num_bg_points, - device=device, ) - if label == 0: - total_points_scores = points_scores.unsqueeze(0) - total_bg_coords = bg_coords.unsqueeze(0) - else: - pad_size = torch.tensor(points_scores.shape[0] - total_points_scores.shape[1]) - pad_tot = torch.max(self.zero_tensor, pad_size) - pad_cur = torch.max(self.zero_tensor, -pad_size) - total_points_scores = F.pad(total_points_scores, (0, 0, 0, pad_tot, 0, 0), value=-1) - points_scores = F.pad(points_scores, (0, 0, 0, pad_cur), value=-1) - - total_points_scores = torch.cat((total_points_scores, points_scores.unsqueeze(0)), dim=0) - total_bg_coords = torch.cat((total_bg_coords, bg_coords.unsqueeze(0)), dim=0) + total_points_scores[label] = points_scores + total_bg_coords[label] = bg_coords return total_points_scores, total_bg_coords - def get_prompt_candidates( + def forward( self, - image_embeddings: torch.Tensor, - label: torch.Tensor, - original_size: torch.Tensor, - threshold: torch.Tensor = torch.tensor([[0.0]], dtype=torch.float32), - num_bg_points: torch.Tensor = torch.tensor([[1]], dtype=torch.int64), - device: torch.device = torch.device("cpu"), - ) -> Tuple[torch.Tensor, torch.Tensor]: + image_embeddings: Tensor, + reference_feat: Tensor, + original_size: Tensor, + threshold: Tensor = torch.tensor([[0.0]], dtype=torch.float32), + num_bg_points: Tensor = torch.tensor([[1]], dtype=torch.int64), + ) -> Tuple[Tensor, Tensor]: """Get prompt candidates from given reference and target features.""" - assert original_size.dim() == 2 and threshold.dim() == 2 and num_bg_points.dim() == 2 + original_size = original_size.squeeze() + threshold = threshold.squeeze() + num_bg_points = num_bg_points.squeeze() target_feat = image_embeddings.squeeze() c_feat, h_feat, w_feat = target_feat.shape target_feat = target_feat / target_feat.norm(dim=0, keepdim=True) target_feat = target_feat.reshape(c_feat, h_feat * w_feat) - sim = self.reference_feats[label].to(device) @ target_feat + sim = reference_feat @ target_feat sim = sim.reshape(1, 1, h_feat, w_feat) - sim = ZeroShotSegmentAnything.mask_postprocessing(sim, self.image_size, original_size[0]) + sim = ZeroShotSegmentAnything.postprocess_masks(sim, self.image_size, original_size) threshold = (threshold == 0) * self.default_threshold_target + threshold points_scores, bg_coords = self._point_selection( mask_sim=sim[0, 0], - original_size=original_size[0], + original_size=original_size, threshold=threshold, num_bg_points=num_bg_points, ) @@ -146,16 +108,16 @@ def get_prompt_candidates( def _point_selection( self, 
- mask_sim: torch.Tensor, - original_size: torch.Tensor, - threshold: torch.Tensor, - num_bg_points: torch.Tensor = torch.tensor([[1]], dtype=torch.int64), - ) -> Tuple[torch.Tensor, torch.Tensor]: + mask_sim: Tensor, + original_size: Tensor, + threshold: Union[Tensor, float] = 0.0, + num_bg_points: Union[Tensor, int] = 1, + ) -> Tuple[Tensor, Tensor]: """Select point used as point prompts.""" _, w_sim = mask_sim.shape # Top-last point selection - bg_indices = mask_sim.flatten().topk(num_bg_points[0, 0], largest=False)[1] + bg_indices = mask_sim.flatten().topk(num_bg_points, largest=False)[1] bg_x = (bg_indices // w_sim).unsqueeze(0) bg_y = bg_indices - bg_x * w_sim bg_coords = torch.cat((bg_y, bg_x), dim=0).permute(1, 0) @@ -164,6 +126,10 @@ def _point_selection( point_coords = torch.where(mask_sim > threshold) fg_coords_scores = torch.stack(point_coords[::-1] + (mask_sim[point_coords],), dim=0).T + # to handle empty tensor + len_fg_coords_scores = len(fg_coords_scores) + fg_coords_scores = F.pad(fg_coords_scores, (0, 0, 0, max(0, 1 - len_fg_coords_scores)), value=-1) + ratio = self.image_size / original_size.max() width = (original_size[1] * ratio).to(torch.int64) n_w = width // self.downsizing @@ -183,10 +149,12 @@ def _point_selection( matched_grid = fg_coords_scores.unsqueeze(1) * matched_matrix.unsqueeze(-1) # sample the highest score one of the samples that are in the same grid - points_scores = matched_grid[matched_grid[..., -1].argsort(dim=0, descending=True)[0]].diagonal().T + matched_indices = matched_grid[..., -1].topk(k=1, dim=0, largest=True)[1][0].to(torch.int64) + points_scores = matched_grid[matched_indices].diagonal().T # sort by the highest score - points_scores = points_scores[torch.argsort(points_scores[:, -1], descending=True)] + sorted_points_scores_indices = torch.argsort(points_scores[:, -1], descending=True).to(torch.int64) + points_scores = points_scores[sorted_points_scores_indices] return points_scores, bg_coords @@ -194,37 +162,34 @@ def _point_selection( class ZeroShotSegmentAnything(SegmentAnything): """Zero-shot learning module using Segment Anything.""" - def __init__(self, config: Optional[DictConfig] = None, state_dict: Optional[OrderedDict] = None) -> None: + def __init__( + self, + config: Optional[DictConfig] = None, + manual_config_update: Optional[Dict] = None, + state_dict: Optional[OrderedDict] = None, + ) -> None: if config is None: config = self.set_default_config() - if not config.model.freeze_image_encoder: - logger.warning("config.model.freeze_image_encoder(=False) must be set to True, changed.") - config.model.freeze_image_encoder = True + if ( + manual_config_update is not None + and isinstance(manual_config_update, dict) + and len(manual_config_update) > 0 + ): + for k, v in manual_config_update.items(): + exec(f"config.{k} = {v}") - if not config.model.freeze_prompt_encoder: - logger.warning("config.model.freeze_prompt_encoder(=False) must be set to True, changed.") - config.model.freeze_prompt_encoder = True - - if not config.model.freeze_mask_decoder: - logger.warning("config.model.freeze_mask_decoder(=False) must be set to True, changed.") - config.model.freeze_mask_decoder = True - - prompt_getter_reference_feats = None - prompt_getter_reference_prompts = None - if state_dict: - if "prompt_getter.reference_feats" in state_dict: - prompt_getter_reference_feats = state_dict.pop("prompt_getter.reference_feats") - if "prompt_getter.reference_prompts" in state_dict: - prompt_getter_reference_prompts = 
state_dict.pop("prompt_getter.reference_prompts") + # check freeze conditions + for condition in ["freeze_image_encoder", "freeze_prompt_encoder", "freeze_mask_decoder"]: + if not getattr(config.model, condition, False): + logger.warning(f"config.model.{condition}(=False) must be set to True, changed.") + setattr(config.model, condition, True) super().__init__(config, state_dict) - self.prompt_getter = PromptGetter( - image_size=config.model.image_size, - reference_feats=prompt_getter_reference_feats, - reference_prompts=prompt_getter_reference_prompts, - ) + self.set_empty_reference_info() + + self.prompt_getter = PromptGetter(image_size=config.model.image_size) self.prompt_getter.set_default_thresholds( default_threshold_reference=config.model.default_threshold_reference, default_threshold_target=config.model.default_threshold_target, @@ -233,6 +198,25 @@ def __init__(self, config: Optional[DictConfig] = None, state_dict: Optional[Ord self.point_labels_box = torch.tensor([[2, 3]], dtype=torch.float32) self.has_mask_inputs = [torch.tensor([[0.0]]), torch.tensor([[1.0]])] + self.transforms = get_transform( + image_size=config.model.image_size, mean=config.dataset.normalize.mean, std=config.dataset.normalize.std + ) + + self.path_reference_info = "vpm_zsl_reference_infos/{}/reference_info.pt" + + def load_state_dict_pre_hook(self, state_dict: Dict[str, Any], prefix: str = "", *args, **kwargs) -> None: + """Load reference info manually.""" + _reference_feats: Tensor = state_dict.get( + "reference_info.reference_feats", torch.tensor([], dtype=torch.float32) + ) + _used_indices: Tensor = state_dict.get("reference_info.used_indices", torch.tensor([], dtype=torch.int64)) + self.reference_info = ParameterDict( + { + "reference_feats": Parameter(_reference_feats, requires_grad=False), + "used_indices": Parameter(_used_indices, requires_grad=False), + }, + ) + def set_default_config(self) -> DictConfig: """Set default config when using independently.""" return DictConfig( @@ -247,18 +231,47 @@ def set_default_config(self) -> DictConfig: "freeze_prompt_encoder": True, "image_size": 1024, "mask_threshold": 0.0, - } + "return_single_mask": False, + "use_stability_score": False, + "stability_score_offset": 1.0, + "return_extra_metrics": False, + }, + "dataset": { + "normalize": { + "mean": [123.675, 116.28, 103.53], + "std": [58.395, 57.12, 57.375], + } + }, } ) + def set_empty_reference_info(self) -> None: + """Set empty reference information.""" + reference_feats: Parameter = Parameter(torch.tensor([], dtype=torch.float32), requires_grad=False) + used_indices: Parameter = Parameter(torch.tensor([[]], dtype=torch.int64), requires_grad=False) + self.reference_info = ParameterDict( + { + "reference_feats": reference_feats, + "used_indices": used_indices, + }, + ) + self.is_reference_info_empty = True + + def initialize_reference_info(self) -> None: + """Initialize reference information.""" + self.reference_info["reference_feats"] = Parameter(torch.zeros(0, 1, 256), requires_grad=False) + self.reference_info["used_indices"] = Parameter(torch.tensor([[]], dtype=torch.int64), requires_grad=False) + self.is_reference_info_empty = False + + def expand_reference_info(self, new_largest_label: int) -> None: + """Expand reference info dimensions if newly given processed prompts have more lables.""" + if new_largest_label > (cur_largest_label := len(self.reference_info["reference_feats"]) - 1): + diff = new_largest_label - cur_largest_label + padded_reference_feats = F.pad(self.reference_info["reference_feats"], 
(0, 0, 0, 0, 0, diff), value=0.0) + self.reference_info["reference_feats"] = Parameter(padded_reference_feats, requires_grad=False) + @torch.no_grad() - def learn( - self, - images: torch.Tensor, - processed_prompts: Dict[ScoredLabel, List[Dict[str, torch.Tensor]]], - padding: Union[Tuple[int, ...], torch.Tensor], - original_size: torch.Tensor, - ) -> None: + def learn(self, batch: List[Dict[str, Any]], reset_feat: bool = False) -> Union[None, Tuple[ParameterDict, Tensor]]: """Get reference features. Using given images, get reference features and save it to PromptGetter. @@ -266,111 +279,152 @@ def learn( Currently, single batch is only supported. Args: - images (torch.Tensor): Given images for reference features. - processed_prompts (Dict[ScoredLabel, List[Dict[str, torch.Tensor]]]): The whole class-wise prompts - processed at _preprocess_prompts. - padding (Union[Tuple[int, ...], torch.Tensor]): Padding size. - original_size (torch.Tensor): Original image size. + batch (List[Dict[str, Any]]): List of dictionaries containing images, prompts, and metas. + `batch` must contain images, prompts with bboxes, points, annotations, and polygons. + reset_feat (bool): Whether reset reference_info. + For OTX standalone, resetting reference_info will be conducted in on_train_start. + For other frameworks, setting it to True is required to reset reference_info. Defaults to False. + + Returns: + (Tuple[ParameterDict, Tensor]): reference_info and ref_masks. """ - assert images.shape[0] == 1, "Only single batch is supported." + if reset_feat: + self.initialize_reference_info() + + # preprocess images and prompts + transformed_batch = [self.transforms(b.copy()) for b in batch] + processed_prompts = [self._preprocess_prompts(tb) for tb in transformed_batch] + + # initialize tensors to contain reference features and prompts + largest_label = max([label for pp in processed_prompts for label in pp.keys()]) + self.expand_reference_info(largest_label) + # TODO(sungchul): consider who to handle multiple reference features, currently replace it - self.prompt_getter.initialize() + batch_ref_masks: List[Tensor] = [] + for tb, pp in zip(transformed_batch, processed_prompts): + # assign components + images = tb["images"].unsqueeze(0).to(self.device) # type: ignore[union-attr] + original_size = torch.as_tensor(tb["original_size"]) - image_embeddings = self.image_encoder(images) - ref_feat = image_embeddings.squeeze().permute(1, 2, 0) + image_embeddings = self.image_encoder(images) + processed_embedding = image_embeddings.squeeze().permute(1, 2, 0) - for label, input_prompts in processed_prompts.items(): - if label.name.lower() == "background": - # skip background + ref_masks = torch.zeros(largest_label + 1, *map(int, original_size)) + for label, input_prompts in pp.items(): # TODO (sungchul): how to skip background class - continue - # generate reference mask - # TODO (sungchul): ensemble multi reference features (current : use merged masks) - reference_prompt = torch.zeros(*map(int, original_size), dtype=torch.uint8, device=self.device) - for input_prompt in input_prompts: - if "annotation" in input_prompt: - # directly use annotation information as a mask - reference_prompt[input_prompt.get("annotation") == 1] += 1 - else: - merged_input_prompts = self._merge_prompts(label, input_prompt, processed_prompts) - # TODO (sungchul): they must be processed in `_merge_prompts` - # and it is required to be expanded to other prompts. 
- point_coords = [] - point_labels = [] - if "box" in merged_input_prompts: - for box in merged_input_prompts["box"]: - point_coords.append(box[:2]) - point_labels.append(2) - point_coords.append(box[2:]) - point_labels.append(3) - - if "points" in merged_input_prompts: - raise NotImplementedError() - - if "annotations" in merged_input_prompts: - raise NotImplementedError() - - point_coords = torch.stack(point_coords, dim=0).unsqueeze(0) - point_labels = torch.tensor([point_labels], device=self.device) - masks = self._predict_masks( - image_embeddings=image_embeddings, - point_coords=point_coords, - point_labels=point_labels, - original_size=original_size, - is_cascade=False, + # generate reference mask + # TODO (sungchul): ensemble multi reference features (current : use merged masks) + ref_mask = torch.zeros(*map(int, original_size), dtype=torch.uint8, device=self.device) + for input_prompt in input_prompts: + if (prompt := input_prompt.get("annotations", None)) is not None: + # directly use annotation information as a mask + ref_mask[prompt == 1] += 1 + elif (prompt := input_prompt.get("polygons", None)) is not None: + for polygon in prompt["polygons"]: + contour = [[int(point[0]), int(point[1])] for point in polygon] + mask_from_polygon = np.zeros(original_size, dtype=np.uint8) + mask_from_polygon = cv2.drawContours(mask_from_polygon, np.asarray([contour]), 0, 1, -1) + ref_mask[mask_from_polygon == 1] += 1 + elif (prompt := input_prompt.get("scribble_annotation", None)) is not None: + logger.warning("scribble_annotation is not supported yet.") + continue + elif (prompt := input_prompt.get("scribble_polygon", None)) is not None: + logger.warning("scribble_polygon is not supported yet.") + continue + else: + point_coords = [] + point_labels = [] + if (prompt := input_prompt.get("bboxes", None)) is not None: + point_coords = prompt["point_coords"].reshape(1, 2, 2) + + elif (prompt := input_prompt.get("points", None)) is not None: + point_coords = prompt["point_coords"].reshape(1, 1, 2) + + point_labels = prompt["point_labels"] + + masks = self._predict_masks( + image_embeddings=image_embeddings, + point_coords=point_coords, + point_labels=point_labels, + original_size=original_size, + is_cascade=False, + ) + ref_mask[masks] += 1 + ref_mask = torch.clip(ref_mask, 0, 1).to(torch.float32) + + ref_feat = None + default_threshold_reference = deepcopy(self.prompt_getter.default_threshold_reference) + while ref_feat is None: + logger.info(f"[*] default_threshold_reference : {default_threshold_reference:.4f}") + ref_feat = self._generate_masked_features( + processed_embedding, ref_mask, default_threshold_reference ) - reference_prompt[masks] += 1 - reference_prompt = torch.clip(reference_prompt, 0, 1) - - ref_mask = reference_prompt.to(torch.float32) - reference_feat = None - default_threshold_reference = deepcopy(self.prompt_getter.default_threshold_reference) - while reference_feat is None: - logger.info(f"[*] default_threshold_reference : {default_threshold_reference:.4f}") - reference_feat = self._generate_masked_features( - ref_feat, ref_mask, default_threshold_reference, padding=padding - ) - default_threshold_reference -= 0.05 + default_threshold_reference -= 0.05 - self.prompt_getter.set_reference(label, reference_feat, reference_prompt) + self.reference_info["reference_feats"][label] = ref_feat.detach().cpu() + self.reference_info["used_indices"] = Parameter( + torch.cat((self.reference_info["used_indices"], torch.tensor([[label]])), dim=1), + requires_grad=False, + ) + ref_masks[label] 
= ref_mask.detach().cpu() + batch_ref_masks.append(ref_masks) + return self.reference_info, batch_ref_masks @torch.no_grad() def infer( - self, images: torch.Tensor, original_size: torch.Tensor - ) -> List[List[DefaultDict[int, List[torch.Tensor]]]]: + self, + batch: List[Dict[str, Any]], + reference_feats: Union[np.ndarray, Tensor], + used_indices: Union[np.ndarray, Tensor], + is_cascade: bool = False, + ) -> List[List[DefaultDict[int, List[Tensor]]]]: """Zero-shot inference with reference features. Get target results by using reference features and target images' features. Args: - images (torch.Tensor): Given images for target results. - original_size (torch.Tensor): Original image size. + batch (List[Dict[str, Any]]): List of dictionaries containing images and metas. + reference_feats (Union[np.ndarray, Tensor]): Reference features for target prediction. + If it is np.ndarray, it will be converted to torch tensor. + used_indices (Union[np.ndarray, Tensor]): To check which indices of reference features are validate. + If it is np.ndarray, it will be converted to torch tensor. + is_cascade (bool): Whether use cascade inference. Defaults to False. Returns: - (List[List[DefaultDict[int, List[torch.Tensor]]]]): Target results. + (List[List[DefaultDict[int, List[Tensor]]]]): Target results. Lists wrapping results is following this order: 1. Target images 2. Tuple of predicted masks and used points gotten by point selection """ - assert images.shape[0] == 1, "Only single batch is supported." + if isinstance(reference_feats, np.ndarray): + reference_feats = torch.as_tensor(reference_feats, device=self.device) + if isinstance(used_indices, np.ndarray): + used_indices = torch.as_tensor(used_indices, device=self.device) - total_results = [] - for image in images: - if image.ndim == 3: - image = image.unsqueeze(0) + # preprocess images and prompts + transformed_batch = [self.transforms(b.copy()) for b in batch] - image_embeddings = self.image_encoder(images) + total_results: List[List[Tensor]] = [] + for tb in transformed_batch: + # assign components + images = tb["images"].unsqueeze(0).to(self.device) # type: ignore[union-attr] + original_size = torch.as_tensor(tb["original_size"]) - total_points_scores, total_bg_coords = self.prompt_getter( - image_embeddings=image_embeddings, original_size=original_size + image_embeddings = self.image_encoder(images) + total_points_scores, total_bg_coords = self.prompt_getter.get_prompt_candidates( + image_embeddings=image_embeddings, + reference_feats=reference_feats, + used_indices=used_indices, + original_size=original_size, + device=self.device, ) predicted_masks: defaultdict = defaultdict(list) used_points: defaultdict = defaultdict(list) - for label, (points_scores, bg_coords) in enumerate(zip(total_points_scores, total_bg_coords)): + for label in total_points_scores.keys(): + points_scores = total_points_scores[label] + bg_coords = total_bg_coords[label] for points_score in points_scores: - if points_score[-1] == -1: - continue x, y = points_score[:2] is_done = False for pm in predicted_masks.get(label, []): @@ -382,9 +436,7 @@ def infer( continue point_coords = torch.cat((points_score[:2].unsqueeze(0), bg_coords), dim=0).unsqueeze(0) - point_coords = ResizeLongestSide.apply_coords( - point_coords, original_size[0], self.config.model.image_size - ) + point_coords = self._preprocess_coords(point_coords, original_size, self.config.model.image_size) point_labels = torch.tensor( [1] + [0] * len(bg_coords), dtype=torch.float32, device=self.device 
).unsqueeze(0) @@ -392,23 +444,24 @@ def infer( image_embeddings=image_embeddings, point_coords=point_coords, point_labels=point_labels, - original_size=original_size[0], + original_size=original_size, + is_cascade=is_cascade, ) predicted_masks[label].append((mask * points_score[2]).detach().cpu()) used_points[label].append(points_score.detach().cpu()) # check overlapping area between different label masks - self.__inspect_overlapping_areas(predicted_masks, used_points) + self._inspect_overlapping_areas(predicted_masks, used_points) total_results.append([predicted_masks, used_points]) return total_results - def __inspect_overlapping_areas( + def _inspect_overlapping_areas( self, - predicted_masks: Dict[int, List[torch.Tensor]], - used_points: Dict[int, List[torch.Tensor]], + predicted_masks: Dict[int, List[Tensor]], + used_points: Dict[int, List[Tensor]], threshold_iou: float = 0.8, - ): - def __calculate_mask_iou(mask1: torch.Tensor, mask2: torch.Tensor): + ) -> None: + def _calculate_mask_iou(mask1: Tensor, mask2: Tensor): assert mask1.ndim == 2 and mask2.ndim == 2 intersection = torch.logical_and(mask1, mask2).sum().item() union = torch.logical_or(mask1, mask2).sum().item() @@ -426,32 +479,34 @@ def __calculate_mask_iou(mask1: torch.Tensor, mask2: torch.Tensor): overlapped_label = [] overlapped_other_label = [] for (im, mask), (jm, other_mask) in product(enumerate(masks), enumerate(other_masks)): - if __calculate_mask_iou(mask, other_mask) > threshold_iou: + if _calculate_mask_iou(mask, other_mask) > threshold_iou: if used_points[label][im][2] > used_points[other_label][jm][2]: overlapped_other_label.append(jm) else: overlapped_label.append(im) - for im in overlapped_label[::-1]: + for im in sorted(list(set(overlapped_label)), reverse=True): masks.pop(im) used_points[label].pop(im) - for jm in overlapped_other_label[::-1]: + for jm in sorted(list(set(overlapped_other_label)), reverse=True): other_masks.pop(jm) used_points[other_label].pop(jm) def _predict_masks( self, - image_embeddings: torch.Tensor, - point_coords: torch.Tensor, - point_labels: torch.Tensor, - original_size: torch.Tensor, + image_embeddings: Tensor, + point_coords: Tensor, + point_labels: Tensor, + original_size: Tensor, is_cascade: bool = True, - ) -> torch.Tensor: + ) -> Tensor: """Predict target masks.""" - logits: torch.Tensor - scores: torch.Tensor - for i in range(3): + masks: Tensor + logits: Tensor + scores: Tensor + num_iter = 3 if is_cascade else 1 + for i in range(num_iter): if i == 0: # First-step prediction mask_input = torch.zeros(1, 1, *map(lambda x: x * 4, image_embeddings.shape[2:]), device=self.device) @@ -459,7 +514,7 @@ def _predict_masks( elif is_cascade and i == 1: # Cascaded Post-refinement-1 - mask_input, masks = self._postprocess_masks(logits, scores, original_size, is_single=True) # noqa: F821 + mask_input, masks = self._postprocess_masks(masks, logits, scores, is_single=True) # noqa: F821 if masks.sum() == 0: return masks @@ -467,129 +522,150 @@ def _predict_masks( elif is_cascade and i == 2: # Cascaded Post-refinement-2 - mask_input, masks = self._postprocess_masks(logits, scores, original_size) # noqa: F821 + mask_input, masks = self._postprocess_masks(masks, logits, scores) # noqa: F821 if masks.sum() == 0: return masks has_mask_input = self.has_mask_inputs[1].to(self.device) coords = torch.nonzero(masks) y, x = coords[:, 0], coords[:, 1] - box_coords = ResizeLongestSide.apply_coords( - torch.tensor([[[x.min(), y.min()], [x.max(), y.max()]]], dtype=torch.float32, device=self.device), + 
box_coords = self._preprocess_coords( + torch.as_tensor( + [[[x.min(), y.min()], [x.max(), y.max()]]], dtype=torch.float32, device=self.device + ), original_size, self.config.model.image_size, ) point_coords = torch.cat((point_coords, box_coords), dim=1) point_labels = torch.cat((point_labels, self.point_labels_box.to(self.device)), dim=1) - scores, logits = self( + high_res_masks, scores, logits = self( image_embeddings=image_embeddings, point_coords=point_coords, point_labels=point_labels, mask_input=mask_input, has_mask_input=has_mask_input, + orig_size=original_size.unsqueeze(0), ) - - _, masks = self._postprocess_masks(logits, scores, original_size) + masks = high_res_masks > self.config.model.mask_threshold + _, masks = self._postprocess_masks(masks, logits, scores) return masks def training_step(self, batch, batch_idx) -> None: """Training step for `learn`.""" - # TODO (sungchul): each prompt will be assigned with each label - bboxes = batch["bboxes"] - labels = batch["labels"] - # TODO (sungchul): support other below prompts - # points = batch["points"] - # annotations = batch["annotations"] - - # organize prompts based on label - processed_prompts = self._preprocess_prompts(bboxes=bboxes[0], labels=labels[0]) - - self.learn( - images=batch["images"], - processed_prompts=processed_prompts, - padding=batch.get("padding")[0], - original_size=batch.get("original_size")[0], - ) + self.learn(batch) def predict_step(self, batch, batch_idx): """Predict step for `infer`.""" - results = self.infer(images=batch["images"], original_size=batch.get("original_size")[0].unsqueeze(0)) + results = self.infer(batch, self.reference_info["reference_feats"], self.reference_info["used_indices"]) return [result[0] for result in results] # tmp: only mask - def _preprocess_prompts( - self, - bboxes: Optional[torch.Tensor] = None, - points: Optional[torch.Tensor] = None, - annotations: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - ) -> Dict[ScoredLabel, List[Dict[str, torch.Tensor]]]: + def _preprocess_prompts(self, batch: Dict[str, Any]) -> Dict[Any, Any]: """Preprocess prompts. Currently, preprocessing for bounding boxes is only supported. Args: - bboxes (torch.Tensor, optional): Bounding box prompts to be preprocessed. - points (torch.Tensor, optional): Point prompts to be preprocessed, to be supported. - annotations (torch.Tensor, optional): annotation prompts to be preprocessed, to be supported. - labels (torch.Tensor, optional): Assigned labels according to given prompts. - Currently, it is only matched to bboxes, and it will be deprecated. + batch (Dict[str, Any]): Dictionary containing data and prompts information. Returns: - (defaultdict[ScoredLabel, List[Dict[str, torch.Tensor]]]): Processed and arranged each single prompt + (Dict[Any, Any]): Processed and arranged each single prompt using label information as keys. Unlike other prompts, `annotation` prompts will be aggregated as single annotation. 
""" processed_prompts = defaultdict(list) - # TODO (sungchul): will be updated - if bboxes is not None: - for bbox, label in zip(bboxes, labels): - processed_prompts[label].append({"box": bbox.reshape(-1, 4)}) + for prompt_name in ["annotations", "polygons", "bboxes", "points"]: + prompts = batch.get(prompt_name, None) + labels = batch["labels"].get(prompt_name, None) + if prompts is None or len(prompts) == 0: + continue + for prompt, label in zip(prompts, labels): + if isinstance(label, ScoredLabel): + label = int(label.id_) + # TODO (sungchul): revisit annotations and polygons + if prompt_name == "annotations": + processed_prompts[label].append({prompt_name: torch.as_tensor(prompt, device=self.device)}) + elif prompt_name == "polygons": + masks = [] + for polygon in prompt: + contour = [[int(point[0]), int(point[1])] for point in polygon] + mask_from_polygon = np.zeros(batch["original_size"], dtype=np.uint8) + mask_from_polygon = cv2.drawContours(mask_from_polygon, np.asarray([contour]), 0, 1, -1) + masks.append(mask_from_polygon) + processed_prompts[label].append({prompt_name: torch.tensor(prompt, device=self.device)}) + elif prompt_name == "bboxes": + processed_prompts[label].append( + { + prompt_name: { + "point_coords": torch.as_tensor(prompt.reshape(-1, 2, 2), device=self.device), + "point_labels": torch.tensor([[1]], device=self.device), + } + } + ) + elif prompt_name == "points": + processed_prompts[label].append( + { + prompt_name: { + "point_coords": torch.as_tensor(prompt.reshape(-1, 2), device=self.device), + "point_labels": torch.tensor([[1]], device=self.device), + } + } + ) - if points: - pass + processed_prompts = dict(sorted(processed_prompts.items(), key=lambda x: x)) # type: ignore[assignment] + return processed_prompts - if annotations: - pass + def _preprocess_coords( + self, + coords: Tensor, + ori_shape: Union[List[int], Tuple[int, int], Tensor], + target_length: int, + ) -> Tensor: + """Expects a torch tensor of length 2 in the final dimension. - processed_prompts = dict(sorted(processed_prompts.items(), key=lambda x: x[0].id_)) # type: ignore[assignment] - return processed_prompts + Requires the original image size in (H, W) format. + + Args: + coords (Tensor): Coordinates tensor. + ori_shape (Union[List[int], Tuple[int, int], Tensor]): Original size of image. + target_length (int): The length of the longest side of the image. + + Returns: + (Tensor): Resized coordinates. + """ + old_h, old_w = ori_shape + new_h, new_w = self.get_prepadded_size(ori_shape, target_length) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords def _generate_masked_features( self, - feats: torch.Tensor, - masks: torch.Tensor, + feats: Tensor, + masks: Tensor, threshold_mask: float, - padding: Optional[Union[Tuple[int, ...], torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, ...]: + ) -> Tuple[Tensor, ...]: """Generate masked features. Args: - feats (torch.Tensor): Raw reference features. It will be filtered with masks. - masks (torch.Tensor): Reference masks used to filter features. + feats (Tensor): Raw reference features. It will be filtered with masks. + masks (Tensor): Reference masks used to filter features. threshold_mask (float): Threshold to control masked region. - padding (Union[Tuple[int, ...], torch.Tensor], optional): Padding size. Returns: - (torch.Tensor): Masked features. + (Tensor): Masked features. 
""" - if padding: - resized_size = ( - self.config.model.image_size - padding[1] - padding[3], - self.config.model.image_size - padding[0] - padding[2], - ) - else: - resized_size = (self.config.model.image_size, self.config.model.image_size) + scale_factor = self.config.model.image_size / max(masks.shape) # Post-process masks - masks = F.interpolate(masks.unsqueeze(0).unsqueeze(0), size=resized_size, mode="bilinear").squeeze() - masks = self._preprocess_masks(masks) + masks = F.interpolate(masks.unsqueeze(0).unsqueeze(0), scale_factor=scale_factor, mode="bilinear").squeeze() + masks = self._pad_to_square(masks) masks = F.interpolate(masks.unsqueeze(0).unsqueeze(0), size=feats.shape[0:2], mode="bilinear").squeeze() # Target feature extraction if (masks > threshold_mask).sum() == 0: # (for stability) there is no area to be extracted - return None, None + return None masked_feat = feats[masks > threshold_mask] masked_feat = masked_feat.mean(0).unsqueeze(0) @@ -597,16 +673,15 @@ def _generate_masked_features( return masked_feat - def _preprocess_masks(self, x: torch.Tensor) -> torch.Tensor: - """Normalize pixel values and pad to a square input. + def _pad_to_square(self, x: Tensor) -> Tensor: + """Pad to a square input. Args: - x (torch.Tensor): Mask to be padded. + x (Tensor): Mask to be padded. Returns: - (torch.Tensor): Padded mask. + (Tensor): Padded mask. """ - # Pad h, w = x.shape[-2:] padh = self.config.model.image_size - h padw = self.config.model.image_size - w @@ -615,15 +690,12 @@ def _preprocess_masks(self, x: torch.Tensor) -> torch.Tensor: def _postprocess_masks( self, - logits: torch.Tensor, - scores: torch.Tensor, - original_size: torch.Tensor, + masks: Tensor, + logits: Tensor, + scores: Tensor, is_single: bool = False, ): """Post-process masks for cascaded post-refinements.""" - high_res_masks = self.mask_postprocessing(logits, self.config.model.image_size, original_size) - masks = high_res_masks > self.config.model.mask_threshold - if is_single: best_idx = 0 else: @@ -638,65 +710,11 @@ def _postprocess_masks( if len(scores[0]) == 0: # all predicted masks were zero masks, ignore them. - return None, torch.zeros((self.config.model.image_size, self.config.model.image_size), device="cpu") + return None, torch.zeros(masks.shape[-2:], device="cpu") best_idx = torch.argmax(scores[0]) return logits[:, best_idx], masks[0, best_idx] - def _update_value(self, target: Dict[str, Any], key: str, value: torch.Tensor) -> None: - """Update tensor to target dictionary. - - Args: - target (Dict[str, Any]): Target dictionary to be updated. - key (str): Key to be used for update. - value (torch.Tensor): Value to be used for update. - """ - if key in target: - target[key] = torch.cat((target[key], value)) - else: - target[key] = value - - def _merge_prompts( - self, - label: ScoredLabel, - input_prompts: Dict[str, torch.Tensor], - processed_prompts: Dict[ScoredLabel, List[Dict[str, torch.Tensor]]], - use_only_background: bool = True, - ) -> Dict[str, torch.Tensor]: - """Merge target prompt and other prompts. - - Merge a foreground prompt and other prompts (background or prompts with other classes). - - Args: - label (ScoredLabel): Label information. Background is 0 and other foregrounds are >= 0. - input_prompts (Dict[str, torch.Tensor]): A foreground prompt to be merged with other prompts. - processed_prompts (Dict[ScoredLabel, List[Dict[str, torch.Tensor]]]): The whole class-wise prompts - processed at _preprocess_prompts. 
- use_only_background (bool): Whether merging only background prompt, defaults to True. - It is applied to only point_coords. - - Returns: - (Dict[str, torch.Tensor]): Merged prompts. - """ - merged_input_prompts = deepcopy(input_prompts) - for other_label, other_input_prompts in processed_prompts.items(): - if other_label.id_ == label.id_: - continue - if (use_only_background and other_label.id_ == 0) or (not use_only_background): - # only add point (and scribble) prompts - # use_only_background=True -> background prompts are only added as background - # use_only_background=False -> other prompts are added as background - for other_input_prompt in other_input_prompts: - if "point_coords" in other_input_prompt: - # point, scribble - self._update_value(merged_input_prompts, "point_coords", other_input_prompt.get("point_coords")) - self._update_value( - merged_input_prompts, - "point_labels", - torch.zeros_like(other_input_prompt.get("point_labels")), - ) - return merged_input_prompts - def set_metrics(self) -> None: """Skip set_metrics unused in zero-shot learning.""" pass @@ -705,6 +723,41 @@ def configure_optimizers(self) -> None: """Skip configure_optimizers unused in zero-shot learning.""" pass + def _find_latest_reference_info(self, root: str = "vpm_zsl_reference_infos") -> Union[str, None]: + """Find latest reference info to be used.""" + if not os.path.isdir(root): + return None + if len(stamps := sorted(os.listdir(root), reverse=True)) > 0: + return stamps[0] + return None + + def on_train_start(self) -> None: + """Called at the beginning of training after sanity check.""" + self.initialize_reference_info() + + def on_predict_start(self) -> None: + """Called at the beginning of predicting.""" + if (latest_stamp := self._find_latest_reference_info()) is not None: + latest_reference_info = self.path_reference_info.format(latest_stamp) + self.reference_info = torch.load(latest_reference_info) + self.reference_info.to(self.device) + logger.info(f"reference info saved at {latest_reference_info} was successfully loaded.") + def training_epoch_end(self, outputs) -> None: - """Skip training_epoch_end unused in zero-shot learning.""" - pass + """Called in the training loop at the very end of the epoch.""" + self.reference_info["used_indices"] = Parameter( + self.reference_info["used_indices"].unique().unsqueeze(0), requires_grad=False + ) + if self.config.model.save_outputs: + path_reference_info = self.path_reference_info.format(time.strftime("%Y%m%d-%H%M%S")) + os.makedirs(os.path.dirname(path_reference_info), exist_ok=True) + torch.save(self.reference_info, path_reference_info) + pickle.dump( + {k: v.numpy() for k, v in self.reference_info.items()}, + open(path_reference_info.replace(".pt", ".pickle"), "wb"), + ) + json.dump( + repr(self.trainer.datamodule.train_dataset.dataset), + open(path_reference_info.replace("reference_info.pt", "reference_meta.json"), "w"), + ) + logger.info(f"Saved reference info at {path_reference_info}.") diff --git a/src/otx/algorithms/visual_prompting/configs/base/configuration.py b/src/otx/algorithms/visual_prompting/configs/base/configuration.py index d7383c28c69..2961fa8a837 100644 --- a/src/otx/algorithms/visual_prompting/configs/base/configuration.py +++ b/src/otx/algorithms/visual_prompting/configs/base/configuration.py @@ -132,6 +132,15 @@ class __Postprocessing(ParameterGroup): affects_outcome_of=ModelLifecycle.INFERENCE, ) + default_threshold_reference = configurable_float( + default_value=0.3, + header="Default reference threshold", + 
description="The threshold to get target area in the mask for reference features.", + min_value=-1.0, + max_value=1.0, + affects_outcome_of=ModelLifecycle.INFERENCE, + ) + @attrs class __POTParameter(BaseConfig.BasePOTParameter): header = string_attribute("POT Parameters") diff --git a/src/otx/algorithms/visual_prompting/configs/sam_tiny_vit/config.yaml b/src/otx/algorithms/visual_prompting/configs/sam_tiny_vit/config.yaml index 36e9748338b..f0dd50ca827 100644 --- a/src/otx/algorithms/visual_prompting/configs/sam_tiny_vit/config.yaml +++ b/src/otx/algorithms/visual_prompting/configs/sam_tiny_vit/config.yaml @@ -15,6 +15,8 @@ dataset: - 57.12 - 57.375 offset_bbox: 20 # randomness for generating bounding box, pixel + use_point: false + use_bbox: false model: name: SAM diff --git a/src/otx/algorithms/visual_prompting/configs/zero_shot_sam_tiny_vit/config.yaml b/src/otx/algorithms/visual_prompting/configs/zero_shot_sam_tiny_vit/config.yaml index 097390fba0f..a50ea244fbd 100644 --- a/src/otx/algorithms/visual_prompting/configs/zero_shot_sam_tiny_vit/config.yaml +++ b/src/otx/algorithms/visual_prompting/configs/zero_shot_sam_tiny_vit/config.yaml @@ -15,8 +15,8 @@ dataset: - 57.12 - 57.375 offset_bbox: 0 - generate_point: false - generate_bbox: false + use_point: false + use_bbox: false model: name: SAM @@ -36,6 +36,7 @@ model: # zero-shot default_threshold_reference: 0.3 default_threshold_target: 0.65 + save_outputs: True # PL Trainer Args. Don't add extra parameter here. trainer: diff --git a/src/otx/algorithms/visual_prompting/tasks/inference.py b/src/otx/algorithms/visual_prompting/tasks/inference.py index 1123cd20c87..86a45131b52 100644 --- a/src/otx/algorithms/visual_prompting/tasks/inference.py +++ b/src/otx/algorithms/visual_prompting/tasks/inference.py @@ -179,7 +179,7 @@ def get_model(config: DictConfig, train_type: TrainType, state_dict: Optional[Or SegmentAnything as VisualPrompter, ) elif train_type == TrainType.Zeroshot: - from otx.algorithms.visual_prompting.adapters.pytorch_lightning.models import ( + from otx.algorithms.visual_prompting.adapters.pytorch_lightning.models import ( # type: ignore[assignment] # noqa: E501 ZeroShotSegmentAnything as VisualPrompter, ) @@ -305,8 +305,9 @@ def _export_to_onnx(self, onnx_path: Dict[str, str]): "point_labels": torch.randint(low=0, high=4, size=(1, 2), dtype=torch.float32), "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float32), "has_mask_input": torch.tensor([[1]], dtype=torch.float32), + "orig_size": torch.randint(low=256, high=2048, size=(1, 2), dtype=torch.int64), } - output_names = ["iou_predictions", "low_res_masks"] + output_names = ["upscaled_masks", "iou_predictions", "low_res_masks"] model_to_export = self.model with warnings.catch_warnings(): @@ -640,16 +641,19 @@ def _export_to_onnx(self, onnx_path: Dict[str, str]): model_to_export = self.model.image_encoder elif module == "visual_prompting_prompt_getter": + reference_feat = torch.randn(1, 256, dtype=torch.float32) + reference_feat /= reference_feat.norm(dim=-1, keepdim=True) dummy_inputs = { "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float32), + "reference_feat": reference_feat, "original_size": torch.randint(low=0, high=image_size * 2, size=(1, 2), dtype=torch.int64), - "threshold": torch.tensor([[0.1]], dtype=torch.float32), + "threshold": torch.tensor([[0.0]], dtype=torch.float32), "num_bg_points": torch.randint(low=1, high=image_size, size=(1, 1), dtype=torch.int64), } - output_names = ["total_points_scores", "total_bg_coords"] + 
output_names = ["points_scores", "bg_coords"] dynamic_axes = { - "total_points_scores": {0: "num_labels", 1: "num_points"}, - "total_bg_coords": {0: "num_labels", 1: "num_points"}, + "points_scores": {0: "num_points"}, + "bg_coords": {0: "num_points"}, } model_to_export = self.model.prompt_getter @@ -666,8 +670,9 @@ def _export_to_onnx(self, onnx_path: Dict[str, str]): "point_labels": torch.randint(low=0, high=4, size=(1, 2), dtype=torch.float32), "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float32), "has_mask_input": torch.tensor([[1]], dtype=torch.float32), + "orig_size": torch.randint(low=256, high=2048, size=(1, 2), dtype=torch.int64), } - output_names = ["iou_predictions", "low_res_masks"] + output_names = ["upscaled_masks", "iou_predictions", "low_res_masks"] model_to_export = self.model else: @@ -704,13 +709,8 @@ def save_model(self, output_model: ModelEntity) -> None: logger.info("Saving the model weights and reference features.") model_info = self.model.state_dict() - # TODO (sungchul): is there more efficient way not to manually add properties? - model_info.update( - { - "prompt_getter.reference_feats": self.model.prompt_getter.reference_feats, - "prompt_getter.reference_prompts": self.model.prompt_getter.reference_prompts, - } - ) + model_info.pop("reference_info.reference_feats") + model_info.pop("reference_info.used_indices") buffer = io.BytesIO() torch.save(model_info, buffer) diff --git a/src/otx/algorithms/visual_prompting/tasks/openvino.py b/src/otx/algorithms/visual_prompting/tasks/openvino.py index 0f29f40d17d..02faaa4ee3e 100644 --- a/src/otx/algorithms/visual_prompting/tasks/openvino.py +++ b/src/otx/algorithms/visual_prompting/tasks/openvino.py @@ -17,16 +17,19 @@ import io import json import os +import pickle import random import tempfile import time from collections import defaultdict +from copy import deepcopy from itertools import product from pathlib import Path from typing import Any, DefaultDict, Dict, List, Optional, Tuple, Type, Union from zipfile import ZipFile import attr +import cv2 import nncf import numpy as np import openvino.runtime as ov @@ -149,11 +152,19 @@ def __init__( self.transform = get_transform() # TODO (sungchul): insert args def pre_process( - self, dataset_item: DatasetItemEntity, extra_processing: bool = False + self, + dataset_item: DatasetItemEntity, + extra_processing: bool = False, + use_bbox: bool = False, + use_point: bool = False, ) -> Tuple[Dict[str, Any], Dict[str, Any], List[Dict[str, Any]]]: """Pre-process function of OpenVINO Visual Prompting Inferencer for image encoder.""" + if use_bbox and use_point: + logger.warning("If both use_bbox and use_point are set, bboxes and points will be generated randomly.") + + prob = 1.0 if not use_point else 0.0 if not use_bbox and use_point else 0.5 images, meta = self.model["image_encoder"].preprocess(dataset_item.numpy, extra_processing) - prompts = OTXVisualPromptingDataset.get_prompts(dataset_item, self.labels) # to be replaced + prompts = OTXVisualPromptingDataset.get_prompts(dataset_item, self.labels, prob=prob) prompts = self.model["decoder"].preprocess(prompts, meta) return images, meta, prompts # type: ignore @@ -176,12 +187,12 @@ def predict(self, dataset_item: DatasetItemEntity) -> List[Annotation]: # type: soft_predictions: List[np.ndarray] = [] for prompt in prompts: label = prompt.pop("label") - orig_size = prompt.pop("orig_size") prompt.update(image_embeddings) # forward decoder to get predicted mask prediction = self.forward_decoder(prompt) - metadata = 
{"label": label, "original_size": orig_size} + prediction["scores"] = prediction["iou_predictions"] + metadata = {"label": label} # set annotation for eval annotation, hard_prediction, soft_prediction = self.post_process(prediction, metadata) @@ -203,10 +214,6 @@ def await_all(self) -> None: self.model["image_encoder"].await_all() self.model["decoder"].await_all() - def pre_process_prompt_getter(self, *args, **kwargs) -> Any: - """Pre-process function of OpenVINO Zero-shot VIsual Prompting Inferencer for prompt getter.""" - pass - class OpenVINOZeroShotVisualPromptingInferencer(OpenVINOVisualPromptingInferencer): """Inferencer implementation for Zero-shot Visual Prompting using OpenVINO backend. @@ -250,7 +257,13 @@ def __init__( **attr.asdict( hparams.postprocessing, filter=lambda attr, value: attr.name - in ["image_size", "sim_threshold", "num_bg_points", "embedded_processing"], + in [ + "image_size", + "sim_threshold", + "num_bg_points", + "embedded_processing", + "default_threshold_reference", + ], ) }, "decoder": { @@ -289,44 +302,107 @@ def __init__( self.point_labels_box = np.array([[2, 3]], dtype=np.float32) self.has_mask_inputs = [np.array([[0.0]]), np.array([[1.0]])] - def pre_process( # type: ignore - self, dataset_item: DatasetItemEntity, extra_processing: bool = False + self.reference_feats: Optional[np.ndarray] = None + self.used_indices: Optional[np.ndarray] = None + + def pre_process_image_encoder( + self, inputs: np.ndarray, extra_processing: bool = False ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: """Pre-process function of OpenVINO Zero-shot Visual Prompting Inferencer for image encoder.""" - return self.model["image_encoder"].preprocess(dataset_item.numpy, extra_processing) + return self.model["image_encoder"].preprocess(inputs, extra_processing) - def pre_process_prompt_getter( - self, image_embeddings: Dict[str, np.ndarray], original_size: np.ndarray - ) -> Dict[str, np.ndarray]: - """Pre-process function of OpenVINO Zero-shot VIsual Prompting Inferencer for prompt getter.""" - inputs_prompt_getter = { - "original_size": original_size[None], - "threshold": np.array([[self.model["prompt_getter"].sim_threshold]], dtype=np.float32), - "num_bg_points": np.array([[self.model["prompt_getter"].num_bg_points]], dtype=np.int64), - } - inputs_prompt_getter.update(image_embeddings) - return inputs_prompt_getter + def learn( + self, + dataset_item: DatasetItemEntity, + reset_feat: bool = False, + use_bbox: bool = False, + use_point: bool = False, + path_reference_info: str = "vpm_zsl_reference_infos/{}/reference_info.pickle", + ) -> Tuple[Dict[str, np.ndarray], np.ndarray]: + """Learn for reference features.""" + ref_masks: np.ndarray + + if reset_feat or self.reference_feats is None: + self.initialize_reference_info() + + images, meta, prompts = self.pre_process(dataset_item, use_bbox, use_point) + largest_label: int = max([int(p["label"].id) for p in prompts]) + self.expand_reference_info(largest_label) - def predict(self, dataset_item: DatasetItemEntity) -> List[Annotation]: # type: ignore + image_embeddings = self.forward_image_encoder(images) + processed_embedding = image_embeddings["image_embeddings"].squeeze().transpose(1, 2, 0) + original_size = meta["original_shape"][:2] + + ref_masks = np.zeros((largest_label + 1, *map(int, original_size)), dtype=np.uint8) + for prompt in prompts: + if "point_coords" in prompt: + # bboxes and points + label = prompt.pop("label") + original_size = prompt.get("orig_size") + prompt.update(image_embeddings) + + prediction = 
self.forward_decoder(prompt, original_size, is_cascade=False) + ref_mask = prediction["upscaled_masks"] + else: + logger.warning("annotation and polygon will be supported.") + continue + ref_masks[int(label.id)] += ref_mask + + ref_masks = np.clip(ref_masks, 0, 1) + for label in range(largest_label + 1): + ref_mask = ref_masks[label] + if ref_mask.sum() == 0: + # empty prediction + continue + + ref_feat = None + default_threshold_reference = deepcopy(self.model["prompt_getter"].default_threshold_reference) + while ref_feat is None: + logger.info(f"[*] default_threshold_reference : {default_threshold_reference:.4f}") + ref_feat = self._generate_masked_features( + processed_embedding, ref_masks[label], default_threshold_reference + ) + default_threshold_reference -= 0.05 + + self.reference_feats[label] = ref_feat + self.used_indices = np.concatenate((self.used_indices, np.array([[label]])), axis=1) + + reference_info = {"reference_feats": self.reference_feats, "used_indices": self.used_indices} + path_reference_info = path_reference_info.format(time.strftime("%Y%m%d-%H%M%S")) + logger.info(f"Saved reference info at {path_reference_info}.") + pickle.dump(reference_info, open(path_reference_info, "wb")) + return reference_info, ref_masks + + def infer( + self, + images: np.ndarray, + reference_feats: np.ndarray, + used_indices: np.ndarray, + is_cascade: bool = False, + ) -> Tuple[List[Any], DefaultDict[Any, Any], DefaultDict[Any, Any]]: """Perform a prediction for a given input image.""" + points_score: np.ndarray + # forward image encoder - images, meta = self.pre_process(dataset_item) - original_size = np.array(meta["original_shape"][:2], dtype=np.int64) + images, meta = self.pre_process_image_encoder(images) + original_size = np.asarray([meta["original_shape"][:2]], dtype=np.int64) image_embeddings = self.forward_image_encoder(images) # get point candidates - inputs_prompt_getter = self.pre_process_prompt_getter(image_embeddings, original_size) - total_prompts = self.forward_prompt_getter(inputs_prompt_getter) + total_points_scores, total_bg_coords = self.forward_prompt_getter( + image_embeddings, reference_feats, used_indices, original_size + ) annotations: DefaultDict = defaultdict(list) predicted_masks: DefaultDict = defaultdict(list) used_points: DefaultDict = defaultdict(list) - for label, (points_scores, bg_coords) in enumerate( - zip(total_prompts["total_points_scores"], total_prompts["total_bg_coords"]) - ): + for label in total_points_scores.keys(): + points_scores = total_points_scores[label] + bg_coords = total_bg_coords[label] for points_score in points_scores: - if points_score[-1] == -1: + if points_score[-1] in [-1.0, 0.0]: continue + x, y = points_score[:2] is_done = False for pm in predicted_masks.get(label, []): @@ -338,37 +414,76 @@ def predict(self, dataset_item: DatasetItemEntity) -> List[Annotation]: # type: continue point_coords = np.concatenate((np.array([[x, y]]), bg_coords), axis=0, dtype=np.float32) - point_coords = self.model["decoder"]._apply_coords(point_coords, original_size) + point_coords = self.model["decoder"]._apply_coords(point_coords, original_size[0]) point_labels = np.array([1] + [0] * len(bg_coords), dtype=np.float32) - inputs_decoder = {"point_coords": point_coords[None], "point_labels": point_labels[None]} + inputs_decoder = { + "point_coords": point_coords[None], + "point_labels": point_labels[None], + "orig_size": original_size, + } inputs_decoder.update(image_embeddings) - prediction = self.forward_decoder(inputs_decoder, original_size) - 
metadata = { - "label": [_label for _label in self.labels if int(_label.id_) == label][0], - "original_size": original_size[None], - } + prediction = self.forward_decoder(inputs_decoder, original_size, is_cascade) + prediction.update({"scores": points_score[-1]}) - # set annotation for eval - annotation, hard_prediction, _ = self.post_process(prediction, metadata) - annotations[label].extend(annotation) - predicted_masks[label].append(hard_prediction) + predicted_masks[label].append(prediction[self.model["decoder"].output_blob_name]) used_points[label].append(points_score) - self.__inspect_overlapping_areas(predicted_masks, used_points, annotations) - return sum(annotations.values(), []) - def forward_prompt_getter(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + self._inspect_overlapping_areas(predicted_masks, used_points) + + for label, predictions in predicted_masks.items(): + if len(predictions) == 0: + continue + metadata = { + "label": [_label for _label in self.labels if int(_label.id_) == label][0], + "original_size": original_size, + } + for prediction, used_point in zip(predictions, used_points[label]): + annotation, _, _ = self.post_process( + {self.model["decoder"].output_blob_name: prediction, "scores": used_point[-1]}, metadata + ) + annotations[label].extend(annotation) + + return sum(annotations.values(), []), predicted_masks, used_points + + def forward_prompt_getter( + self, + image_embeddings: Dict[str, np.ndarray], + reference_feats: np.ndarray, + used_indices: np.ndarray, + original_size: np.ndarray, + ) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]: """Forward function of OpenVINO Visual Prompting Inferencer.""" - return self.model["prompt_getter"].infer_sync(inputs) + inputs = { + "original_size": original_size, + "threshold": np.array([[self.model["prompt_getter"].sim_threshold]], dtype=np.float32), + "num_bg_points": np.array([[self.model["prompt_getter"].num_bg_points]], dtype=np.int64), + **image_embeddings, + } + total_points_scores: Dict[int, np.ndarray] = {} + total_bg_coords: Dict[int, np.ndarray] = {} + for label in used_indices[0]: + reference_feat = reference_feats[label] + inputs["reference_feat"] = reference_feat + outputs = self.model["prompt_getter"].infer_sync(inputs) + + total_points_scores[label] = outputs["points_scores"] + total_bg_coords[label] = outputs["bg_coords"] + + return total_points_scores, total_bg_coords def forward_decoder( # type: ignore - self, inputs: Dict[str, np.ndarray], original_size: np.ndarray + self, + inputs: Dict[str, np.ndarray], + original_size: np.ndarray, + is_cascade: bool = True, ) -> Dict[str, np.ndarray]: """Forward function of OpenVINO Visual Prompting Inferencer.""" + masks: np.ndarray logits: np.ndarray scores: np.ndarray - mask_slice = slice(0, 1) - for i in range(3): + num_iter = 3 if is_cascade else 1 + for i in range(num_iter): if i == 0: # First-step prediction mask_input = np.zeros( @@ -378,44 +493,46 @@ def forward_decoder( # type: ignore elif i == 1: # Cascaded Post-refinement-1 - mask_input, masks, iou_predictions = self._postprocess_masks( - logits, scores, original_size, is_single=True # noqa: F821 - ) + mask_input, masks = self._postprocess_masks(masks, logits, scores, is_single=True) # noqa: F821 if masks.sum() == 0: - return {"iou_predictions": iou_predictions, "low_res_masks": mask_input} + return {"upscaled_masks": masks} has_mask_input = self.has_mask_inputs[1] elif i == 2: # Cascaded Post-refinement-2 - mask_input, masks, iou_predictions = self._postprocess_masks( - 
logits, scores, original_size # noqa: F821 - ) + mask_input, masks = self._postprocess_masks(masks, logits, scores) # noqa: F821 if masks.sum() == 0: - return {"iou_predictions": iou_predictions, "low_res_masks": mask_input} + return {"upscaled_masks": masks} has_mask_input = self.has_mask_inputs[1] y, x = np.nonzero(masks) box_coords = self.model["decoder"]._apply_coords( - np.array([[[x.min(), y.min()], [x.max(), y.max()]]], dtype=np.float32), original_size + np.array([[[x.min(), y.min()], [x.max(), y.max()]]], dtype=np.float32), original_size[0] + ) + inputs.update( + { + "point_coords": np.concatenate((inputs["point_coords"], box_coords), axis=1), + "point_labels": np.concatenate((inputs["point_labels"], self.point_labels_box), axis=1), + } ) - inputs["point_coords"] = np.concatenate((inputs["point_coords"], box_coords), axis=1) - inputs["point_labels"] = np.concatenate((inputs["point_labels"], self.point_labels_box), axis=1) inputs.update({"mask_input": mask_input, "has_mask_input": has_mask_input}) prediction = self.model["decoder"].infer_sync(inputs) - scores, logits = prediction["iou_predictions"], prediction["low_res_masks"] + upscaled_masks, scores, logits = ( + prediction["upscaled_masks"], + prediction["iou_predictions"], + prediction["low_res_masks"], + ) + masks = upscaled_masks > self.model["decoder"].mask_threshold - return {"iou_predictions": scores[:, mask_slice], "low_res_masks": logits[:, mask_slice, :, :]} + _, masks = self._postprocess_masks(masks, logits, scores) + return {"upscaled_masks": masks} def _postprocess_masks( - self, logits: np.ndarray, scores: np.ndarray, original_size: np.ndarray, is_single: bool = False + self, masks: np.ndarray, logits: np.ndarray, scores: np.ndarray, is_single: bool = False ) -> Tuple[np.ndarray, ...]: """Post-process logits for resized masks according to best index based on scores.""" - high_res_masks = self.model["decoder"].resize_and_crop(logits[0].transpose(1, 2, 0), original_size) - masks = high_res_masks > self.model["decoder"].mask_threshold - masks = masks.transpose(2, 0, 1)[None] - if is_single: best_idx = 0 else: @@ -430,19 +547,18 @@ def _postprocess_masks( if len(scores[0]) == 0: # all predicted masks were zero masks, ignore them. 
- return None, np.zeros((self.model["decoder"].image_size, self.model["decoder"].image_size)), 0.0 + return None, np.zeros(masks.shape[-2:]) best_idx = np.argmax(scores[0]) - return logits[:, [best_idx]], masks[0, best_idx], scores[0, best_idx] + return logits[:, [best_idx]], masks[0, best_idx] - def __inspect_overlapping_areas( + def _inspect_overlapping_areas( self, predicted_masks: Dict[int, List[np.ndarray]], used_points: Dict[int, List[np.ndarray]], - annotations: Dict[int, List[np.ndarray]], threshold_iou: float = 0.8, ): - def __calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray): + def _calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray): assert mask1.ndim == 2 and mask2.ndim == 2 intersection = np.logical_and(mask1, mask2).sum().item() union = np.logical_or(mask1, mask2).sum().item() @@ -460,21 +576,104 @@ def __calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray): overlapped_label = [] overlapped_other_label = [] for (im, mask), (jm, other_mask) in product(enumerate(masks), enumerate(other_masks)): - if __calculate_mask_iou(mask, other_mask) > threshold_iou: + if _calculate_mask_iou(mask, other_mask) > threshold_iou: if used_points[label][im][2] > used_points[other_label][jm][2]: overlapped_other_label.append(jm) else: overlapped_label.append(im) - for im in overlapped_label[::-1]: + for im in sorted(list(set(overlapped_label)), reverse=True): masks.pop(im) used_points[label].pop(im) - annotations[label].pop(im) - for jm in overlapped_other_label[::-1]: + for jm in sorted(list(set(overlapped_other_label)), reverse=True): other_masks.pop(jm) used_points[other_label].pop(jm) - annotations[other_label].pop(jm) + + def predict(self, dataset_item: DatasetItemEntity) -> List[Annotation]: # type: ignore + """Perform a prediction for a given input image.""" + results = self.infer(dataset_item.numpy, self.reference_feats, self.used_indices) + return results[0] + + def _find_latest_reference_info(self, root: str = "vpm_zsl_reference_infos") -> Union[str, None]: + """Find latest reference info to be used.""" + if not os.path.isdir(root): + return None + if len(stamps := sorted(os.listdir(root), reverse=True)) > 0: + return stamps[0] + return None + + def _get_reference_info( + self, root: str = "vpm_zsl_reference_infos", path_reference_info: str = "{}/reference_info.pickle" + ) -> Union[Tuple[np.ndarray, np.ndarray], None]: + """Get reference info through loading previously saved one or running `learn`.""" + if (latest_stamp := self._find_latest_reference_info(root)) is not None: + # load previously saved reference info + latest_reference_info = os.path.join(root, path_reference_info.format(latest_stamp)) + reference_info = pickle.load(open(latest_reference_info, "rb")) + return reference_info["reference_feats"], reference_info["used_indices"] + return None, None + + def initialize_reference_info(self) -> None: + """Initialize reference information.""" + self.reference_feats = np.zeros((0, 1, 256), dtype=np.float32) + self.used_indices = np.array([[]], dtype=np.int64) + + def expand_reference_info(self, new_largest_label: int) -> None: + """Expand reference info dimensions if newly given processed prompts have more labels.""" + if new_largest_label > (cur_largest_label := len(self.reference_feats) - 1): + diff = new_largest_label - cur_largest_label + self.reference_feats = np.pad(self.reference_feats, ((0, diff), (0, 0), (0, 0)), constant_values=0.0) + + def _generate_masked_features( + self, + feats: np.ndarray, + masks: np.ndarray, + threshold_mask: float, + ) -> 
Tuple[np.ndarray, ...]: + """Generate masked features. + + Args: + feats (np.ndarray): Raw reference features. It will be filtered with masks. + masks (np.ndarray): Reference masks used to filter features. + threshold_mask (float): Threshold to control masked region. + + Returns: + (np.ndarray): Masked features. + """ + target_shape = self.model["image_encoder"].image_size / max(masks.shape) * np.array(masks.shape) + target_shape = target_shape[::-1].astype(np.int32) + + # Post-process masks + masks = cv2.resize(masks, target_shape, interpolation=cv2.INTER_LINEAR) + masks = self._pad_to_square(masks) + masks = cv2.resize(masks, feats.shape[:2][::-1], interpolation=cv2.INTER_LINEAR) + + # Target feature extraction + if (masks > threshold_mask).sum() == 0: + # (for stability) there is no area to be extracted + return None + + masked_feat = feats[masks > threshold_mask] + masked_feat = masked_feat.mean(0)[None] + masked_feat = masked_feat / np.linalg.norm(masked_feat, axis=-1, keepdims=True) + + return masked_feat + + def _pad_to_square(self, x: np.ndarray) -> np.ndarray: + """Pad to a square input. + + Args: + x (np.ndarray): Mask to be padded. + + Returns: + (np.ndarray): Padded mask. + """ + h, w = x.shape[-2:] + padh = self.model["image_encoder"].image_size - h + padw = self.model["image_encoder"].image_size - w + x = np.pad(x, ((0, padh), (0, padw)), constant_values=0.0) + return x class OTXOpenVinoDataLoader: @@ -487,6 +686,7 @@ def __init__( module_name: str, shuffle: bool = True, output_model: Optional[ModelEntity] = None, + **kwargs, ): self.dataset = dataset self.inferencer = inferencer @@ -526,7 +726,6 @@ def __getitem__(self, index: int): image_embeddings = self.image_encoder(images["images"]) prompt = prompts[0] # only use the first prompt prompt.pop("label") - prompt.pop("orig_size") prompt.update({"image_embeddings": image_embeddings["image_embeddings"]}) return prompt # TODO (sungchul): change has_mask_input @@ -546,6 +745,8 @@ def __init__( module_name: str, shuffle: bool = True, output_model: Optional[ModelEntity] = None, + reference_feats: Optional[np.ndarray] = None, + used_indices: Optional[np.ndarray] = None, ): super().__init__( dataset=dataset, inferencer=inferencer, module_name=module_name, shuffle=shuffle, output_model=output_model @@ -553,6 +754,10 @@ def __init__( if self.module_name == "decoder": self.prompt_getter = self._load_module("prompt_getter", output_model) + self.inferencer: OpenVINOZeroShotVisualPromptingInferencer + self.reference_feats = reference_feats + self.used_indices = used_indices + def __getitem__(self, index: int) -> Dict[str, Any]: """Get item from dataset.""" images: Dict[str, np.ndarray] @@ -561,8 +766,8 @@ def __getitem__(self, index: int) -> Dict[str, Any]: index = self.shuffler[index] items = self.dataset[index] - images, meta = self.inferencer.pre_process(items, extra_processing=True) # type: ignore - original_size = np.array(meta["original_shape"][:2]) + images, meta = self.inferencer.pre_process_image_encoder(items.numpy, extra_processing=True) # type: ignore + original_size = np.asarray([meta["original_shape"][:2]]) _, _, h, w = images["images"].shape pad_width = ((0, 0), (0, 0), (0, self.target_length - h), (0, self.target_length - w)) images["images"] = np.pad(images["images"], pad_width, mode="constant", constant_values=0) @@ -570,23 +775,32 @@ def __getitem__(self, index: int) -> Dict[str, Any]: return images else: image_embeddings = self.image_encoder(images["images"]) - inputs_prompt_getter = 
self.inferencer.pre_process_prompt_getter(image_embeddings, original_size) if self.module_name == "prompt_getter": - return inputs_prompt_getter + return { + "reference_feat": self.reference_feats[self.used_indices[0][0]], # only use the first feature + "original_size": original_size, + "threshold": np.array([[self.inferencer.model["prompt_getter"].sim_threshold]], dtype=np.float32), + "num_bg_points": np.array([[self.inferencer.model["prompt_getter"].num_bg_points]], dtype=np.int64), + **image_embeddings, + } + + total_points_scores, total_bg_coords = self.inferencer.forward_prompt_getter( + image_embeddings, self.reference_feats, self.used_indices, original_size + ) - total_prompts = self.prompt_getter(inputs_prompt_getter) # only use the first prompt - point_score = total_prompts["total_points_scores"][0][0] - bg_coords = total_prompts["total_bg_coords"][0] + point_score: np.ndarray = total_points_scores[0][0] + bg_coords: np.ndarray = total_bg_coords[0] x, y = point_score[:2] point_coords = np.concatenate((np.array([[x, y]]), bg_coords), axis=0, dtype=np.float32) - point_coords = self.inferencer.model["decoder"]._apply_coords(point_coords, original_size) + point_coords = self.inferencer.model["decoder"]._apply_coords(point_coords, original_size[0]) point_labels = np.array([1] + [0] * len(bg_coords), dtype=np.float32) inputs_decoder = {"point_coords": point_coords[None], "point_labels": point_labels[None]} inputs_decoder.update(image_embeddings) inputs_decoder.update( { + "orig_size": original_size, "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32), "has_mask_input": np.zeros((1, 1), dtype=np.float32), } @@ -771,6 +985,7 @@ def optimize( optimization_parameters: Optional[OptimizationParameters] = None, module_names: List[str] = ["image_encoder", "decoder"], ov_dataloader: Type[OTXOpenVinoDataLoader] = OTXOpenVinoDataLoader, + **kwargs, ): """Optimize function of OpenVINOVisualPromptingTask.""" logger.info("Start PTQ optimization") @@ -783,7 +998,9 @@ def optimize( dataset = dataset.get_subset(Subset.TRAINING) for i, module_name in enumerate(module_names, 1): - data_loader = ov_dataloader(dataset, self.inferencer, module_name=module_name, output_model=output_model) + data_loader = ov_dataloader( + dataset, self.inferencer, module_name=module_name, output_model=output_model, **kwargs + ) quantization_dataset = nncf.Dataset(data_loader, lambda data: data) with tempfile.TemporaryDirectory() as tempdir: @@ -863,6 +1080,71 @@ def load_inferencer(self) -> OpenVINOZeroShotVisualPromptingInferencer: num_requests=get_default_async_reqs_num(), ) + def infer( + self, + dataset: DatasetEntity, + inference_parameters: Optional[InferenceParameters] = None, + root: str = "vpm_zsl_reference_infos", + path_reference_info: str = "{}/reference_info.pickle", + ) -> DatasetEntity: + """Infer function of OpenVINOVisualPromptingTask. + + Currently, asynchronous execution is not supported, synchronous execution will be executed instead. + """ + if inference_parameters is not None: + update_progress_callback = inference_parameters.update_progress + enable_async_inference = inference_parameters.enable_async_inference + else: + update_progress_callback = default_progress_callback + enable_async_inference = True + + # FIXME (sungchul): Support async inference. 
+ if enable_async_inference: + logger.warning("Asynchronous inference doesn't work, synchronous inference will be executed.") + enable_async_inference = False + predicted_validation_dataset = dataset.with_empty_annotations() + + def add_prediction(id: int, annotations: List[Annotation]): + dataset_item = predicted_validation_dataset[id] + dataset_item.append_annotations(annotations) + + total_time = 0.0 + dataset_size = len(dataset) + + if self.inferencer.reference_feats is None and self.inferencer.used_indices is None: + # set reference_feats and used_indices from previously saved reference_info + self.inferencer.reference_feats, self.inferencer.used_indices = self.inferencer._get_reference_info( + root, path_reference_info + ) + if self.inferencer.reference_feats is None and self.inferencer.used_indices is None: + # if they are empty, stop inference and return empty dataset + logger.warning( + ( + "reference_feats and used_indices are empty, stop inference and return empty dataset. " + "Please run learn function first." + ) + ) + return predicted_validation_dataset + + for i, dataset_item in enumerate(dataset, 1): + start_time = time.perf_counter() + + annotations = self.inferencer.predict(dataset_item) + add_prediction(i - 1, annotations) + + end_time = time.perf_counter() - start_time + total_time += end_time + update_progress_callback(int(i / dataset_size * 100), None) + + self.inferencer.await_all() + + self._avg_time_per_image = total_time / len(dataset) + logger.info(f"Avg time per image: {self._avg_time_per_image} secs") + logger.info(f"Total time: {total_time} secs") + logger.info("Visual Prompting OpenVINO inference completed") + + return predicted_validation_dataset + def optimize( self, optimization_type: OptimizationType, @@ -871,8 +1153,11 @@ def optimize( optimization_parameters: Optional[OptimizationParameters] = None, module_names: List[str] = ["image_encoder", "prompt_getter", "decoder"], ov_dataloader: Type[OTXOpenVinoDataLoader] = OTXZeroShotOpenVinoDataLoader, + **kwargs, ): """Optimize function of OpenVINOZeroShotVisualPromptingTask.""" + self.inferencer: OpenVINOZeroShotVisualPromptingInferencer + reference_feats, used_indices = self.inferencer._get_reference_info() return super().optimize( optimization_type=optimization_type, dataset=dataset, @@ -880,4 +1165,6 @@ def optimize( optimization_parameters=optimization_parameters, module_names=module_names, ov_dataloader=ov_dataloader, + reference_feats=reference_feats, + used_indices=used_indices, ) diff --git a/tests/assets/car_tree_bug_zero_shot/annotations/instances_train.json b/tests/assets/car_tree_bug_zero_shot/annotations/instances_train.json new file mode 100644 index 00000000000..39dadb88943 --- /dev/null +++ b/tests/assets/car_tree_bug_zero_shot/annotations/instances_train.json @@ -0,0 +1,66 @@ +{ + "licenses": [{ "name": "", "id": 0, "url": "" }], + "info": { + "contributor": "", + "date_created": "", + "description": "", + "url": "", + "version": "", + "year": "" + }, + "categories": [ + { "id": 1, "name": "car", "supercategory": "" }, + { "id": 2, "name": "tree", "supercategory": "" }, + { "id": 3, "name": "bug", "supercategory": "" } + ], + "images": [ + { + "id": 6, + "width": 1280, + "height": 720, + "file_name": "Slide4.PNG", + "license": 0, + "flickr_url": "", + "coco_url": "", + "date_captured": 0 + } + ], + "annotations": [ + { + "id": 16, + "image_id": 6, + "category_id": 3, + "segmentation": [ + [251.2, 150.5, 372.47, 47.31, 596.99, 231.4, 455.05, 376.77] + ], + "area": 53610.0, + "bbox": [251.2, 
47.31, 345.79, 329.46], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 17, + "image_id": 6, + "category_id": 2, + "segmentation": [ + [641.72, 255.48, 731.18, 87.74, 848.17, 144.52, 746.67, 311.4] + ], + "area": 23927.0, + "bbox": [641.72, 87.74, 206.45, 223.66], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 18, + "image_id": 6, + "category_id": 1, + "segmentation": [ + [791.4, 443.9, 984.95, 183.23, 1112.26, 273.55, 910.97, 536.77] + ], + "area": 50412.0, + "bbox": [791.4, 183.23, 320.86, 353.54], + "iscrowd": 0, + "attributes": { "occluded": false } + } + ] +} diff --git a/tests/assets/car_tree_bug_zero_shot/annotations/instances_val.json b/tests/assets/car_tree_bug_zero_shot/annotations/instances_val.json new file mode 100644 index 00000000000..bdbfe9331bf --- /dev/null +++ b/tests/assets/car_tree_bug_zero_shot/annotations/instances_val.json @@ -0,0 +1,316 @@ +{ + "licenses": [{ "name": "", "id": 0, "url": "" }], + "info": { + "contributor": "", + "date_created": "", + "description": "", + "url": "", + "version": "", + "year": "" + }, + "categories": [ + { "id": 1, "name": "car", "supercategory": "" }, + { "id": 2, "name": "tree", "supercategory": "" }, + { "id": 3, "name": "bug", "supercategory": "" } + ], + "images": [ + { + "id": 7, + "width": 1280, + "height": 720, + "file_name": "Slide3.PNG", + "license": 0, + "flickr_url": "", + "coco_url": "", + "date_captured": 0 + }, + { + "id": 8, + "width": 1280, + "height": 720, + "file_name": "Slide4.PNG", + "license": 0, + "flickr_url": "", + "coco_url": "", + "date_captured": 0 + }, + { + "id": 1, + "width": 1280, + "height": 720, + "file_name": "Slide9.PNG", + "license": 0, + "flickr_url": "", + "coco_url": "", + "date_captured": 0 + }, + { + "id": 2, + "width": 1280, + "height": 720, + "file_name": "Slide8.PNG", + "license": 0, + "flickr_url": "", + "coco_url": "", + "date_captured": 0 + }, + { + "id": 3, + "width": 1280, + "height": 720, + "file_name": "Slide7.PNG", + "license": 0, + "flickr_url": "", + "coco_url": "", + "date_captured": 0 + }, + { + "id": 4, + "width": 1280, + "height": 720, + "file_name": "Slide6.PNG", + "license": 0, + "flickr_url": "", + "coco_url": "", + "date_captured": 0 + }, + { + "id": 5, + "width": 1280, + "height": 720, + "file_name": "Slide5.PNG", + "license": 0, + "flickr_url": "", + "coco_url": "", + "date_captured": 0 + } + ], + "annotations": [ + { + "id": 19, + "image_id": 7, + "category_id": 1, + "segmentation": [ + [184.09, 131.61, 338.06, 129.89, 339.78, 457.63, 183.23, 461.08] + ], + "area": 51030.0, + "bbox": [183.23, 129.89, 156.55, 331.19], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 20, + "image_id": 7, + "category_id": 2, + "segmentation": [ + [832.69, 104.09, 1018.49, 102.37, 1017.63, 226.24, 825.81, 233.98] + ], + "area": 23933.0, + "bbox": [825.81, 102.37, 192.68, 131.61], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 21, + "image_id": 7, + "category_id": 3, + "segmentation": [ + [898.92, 490.32, 1195.7, 487.74, 1209.46, 673.55, 913.55, 670.11] + ], + "area": 54157.0, + "bbox": [898.92, 487.74, 310.54, 185.81], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 22, + "image_id": 8, + "category_id": 3, + "segmentation": [ + [341.51, 373.33, 502.4, 456.8, 341.5, 709.7, 188.39, 612.47] + ], + "area": 52814.0, + "bbox": [188.39, 373.33, 314.01, 336.37], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 1, + "image_id": 1, + "category_id": 3, + 
"segmentation": [ + [17.2, 166.88, 203.87, 7.74, 410.32, 43.87, 117.85, 331.18] + ], + "area": 58273.0, + "bbox": [17.2, 7.74, 393.12, 323.44], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 2, + "image_id": 1, + "category_id": 1, + "segmentation": [ + [294.19, 281.29, 643.44, 300.22, 628.82, 469.68, 277.85, 449.03] + ], + "area": 59331.0, + "bbox": [277.85, 281.29, 365.59, 188.39], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 3, + "image_id": 1, + "category_id": 2, + "segmentation": [ + [114.41, 499.79, 30.97, 670.11, 151.4, 705.38, 240.86, 536.77] + ], + "area": 24033.0, + "bbox": [30.97, 499.79, 209.89, 205.59], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 4, + "image_id": 2, + "category_id": 1, + "segmentation": [[165.16, 2.58, 344.95, 41.29, 27.5, 363.0, 9.46, 147.1]], + "area": 53173.0, + "bbox": [9.46, 2.58, 335.49, 360.42], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 5, + "image_id": 2, + "category_id": 2, + "segmentation": [ + [524.73, 378.49, 648.6, 227.96, 762.15, 298.49, 627.96, 458.49] + ], + "area": 26526.0, + "bbox": [524.73, 227.96, 237.42, 230.53], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 6, + "image_id": 2, + "category_id": 3, + "segmentation": [ + [946.24, 652.9, 1191.4, 356.13, 1274.8, 576.3, 1092.5, 715.7] + ], + "area": 55317.0, + "bbox": [946.24, 356.13, 328.56, 359.57], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 7, + "image_id": 3, + "category_id": 2, + "segmentation": [ + [584.95, 221.94, 715.7, 223.66, 706.24, 411.18, 583.23, 413.76] + ], + "area": 24074.0, + "bbox": [583.23, 221.94, 132.47, 191.82], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 8, + "image_id": 3, + "category_id": 1, + "segmentation": [ + [826.67, 222.8, 966.9, 176.3, 1081.29, 489.46, 931.61, 542.8] + ], + "area": 51362.0, + "bbox": [826.67, 176.3, 254.62, 366.5], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 9, + "image_id": 3, + "category_id": 3, + "segmentation": [ + [698.49, 384.52, 864.52, 390.54, 872.26, 688.17, 683.01, 683.01] + ], + "area": 52982.0, + "bbox": [683.01, 384.52, 189.25, 303.65], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 10, + "image_id": 4, + "category_id": 1, + "segmentation": [ + [69.68, 11.18, 67.1, 336.34, 213.33, 338.92, 222.8, 10.32] + ], + "area": 48945.0, + "bbox": [67.1, 10.32, 155.7, 328.6], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 11, + "image_id": 4, + "category_id": 2, + "segmentation": [ + [569.46, 70.54, 688.17, 70.54, 683.01, 262.37, 559.14, 263.23] + ], + "area": 23273.0, + "bbox": [559.14, 70.54, 129.03, 192.69], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 12, + "image_id": 4, + "category_id": 3, + "segmentation": [ + [972.04, 116.13, 1265.38, 95.48, 1274.84, 295.05, 974.62, 292.47] + ], + "area": 55841.0, + "bbox": [972.04, 95.48, 302.8, 199.57], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 13, + "image_id": 5, + "category_id": 3, + "segmentation": [ + [200.43, 336.34, 385.38, 334.62, 382.8, 635.7, 206.45, 638.28] + ], + "area": 54478.0, + "bbox": [200.43, 334.62, 184.95, 303.66], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 14, + "image_id": 5, + "category_id": 2, + "segmentation": [ + [594.41, 523.01, 779.35, 523.87, 778.49, 643.44, 590.97, 645.16] + ], + "area": 22525.0, + "bbox": [590.97, 523.01, 188.38, 
122.15], + "iscrowd": 0, + "attributes": { "occluded": false } + }, + { + "id": 15, + "image_id": 5, + "category_id": 1, + "segmentation": [ + [1101.1, 304.6, 1230.1, 389.6, 1058.1, 665.8, 929.03, 581.51] + ], + "area": 50271.0, + "bbox": [929.03, 304.6, 301.07, 361.2], + "iscrowd": 0, + "attributes": { "occluded": false } + } + ] +} diff --git a/tests/assets/car_tree_bug_zero_shot/images/train/Slide4.PNG b/tests/assets/car_tree_bug_zero_shot/images/train/Slide4.PNG new file mode 100644 index 00000000000..236d848e36a Binary files /dev/null and b/tests/assets/car_tree_bug_zero_shot/images/train/Slide4.PNG differ diff --git a/tests/assets/car_tree_bug_zero_shot/images/val/Slide3.PNG b/tests/assets/car_tree_bug_zero_shot/images/val/Slide3.PNG new file mode 100644 index 00000000000..6cd6e10f702 Binary files /dev/null and b/tests/assets/car_tree_bug_zero_shot/images/val/Slide3.PNG differ diff --git a/tests/assets/car_tree_bug_zero_shot/images/val/Slide4.PNG b/tests/assets/car_tree_bug_zero_shot/images/val/Slide4.PNG new file mode 100644 index 00000000000..22960a01034 Binary files /dev/null and b/tests/assets/car_tree_bug_zero_shot/images/val/Slide4.PNG differ diff --git a/tests/assets/car_tree_bug_zero_shot/images/val/Slide5.PNG b/tests/assets/car_tree_bug_zero_shot/images/val/Slide5.PNG new file mode 100644 index 00000000000..fb979bc1220 Binary files /dev/null and b/tests/assets/car_tree_bug_zero_shot/images/val/Slide5.PNG differ diff --git a/tests/assets/car_tree_bug_zero_shot/images/val/Slide6.PNG b/tests/assets/car_tree_bug_zero_shot/images/val/Slide6.PNG new file mode 100644 index 00000000000..cc5e63a173f Binary files /dev/null and b/tests/assets/car_tree_bug_zero_shot/images/val/Slide6.PNG differ diff --git a/tests/assets/car_tree_bug_zero_shot/images/val/Slide7.PNG b/tests/assets/car_tree_bug_zero_shot/images/val/Slide7.PNG new file mode 100644 index 00000000000..93e61c87f48 Binary files /dev/null and b/tests/assets/car_tree_bug_zero_shot/images/val/Slide7.PNG differ diff --git a/tests/assets/car_tree_bug_zero_shot/images/val/Slide8.PNG b/tests/assets/car_tree_bug_zero_shot/images/val/Slide8.PNG new file mode 100644 index 00000000000..954d34d32e3 Binary files /dev/null and b/tests/assets/car_tree_bug_zero_shot/images/val/Slide8.PNG differ diff --git a/tests/assets/car_tree_bug_zero_shot/images/val/Slide9.PNG b/tests/assets/car_tree_bug_zero_shot/images/val/Slide9.PNG new file mode 100644 index 00000000000..da9448fc35f Binary files /dev/null and b/tests/assets/car_tree_bug_zero_shot/images/val/Slide9.PNG differ diff --git a/tests/e2e/cli/visual_prompting/reference/Zero_Shot_SAM_Tiny_ViT/compressed_decoder.yml b/tests/e2e/cli/visual_prompting/reference/Zero_Shot_SAM_Tiny_ViT/compressed_decoder.yml index 9009e81d953..bbefedd68ef 100644 --- a/tests/e2e/cli/visual_prompting/reference/Zero_Shot_SAM_Tiny_ViT/compressed_decoder.yml +++ b/tests/e2e/cli/visual_prompting/reference/Zero_Shot_SAM_Tiny_ViT/compressed_decoder.yml @@ -1,3 +1,3 @@ TestToolsZeroShotVisualPrompting: ptq: - number_of_fakequantizers: 69 + number_of_fakequantizers: 71 diff --git a/tests/e2e/cli/visual_prompting/reference/Zero_Shot_SAM_Tiny_ViT/compressed_prompt_getter.yml b/tests/e2e/cli/visual_prompting/reference/Zero_Shot_SAM_Tiny_ViT/compressed_prompt_getter.yml index 43f5b4c35ce..e1590800e6a 100644 --- a/tests/e2e/cli/visual_prompting/reference/Zero_Shot_SAM_Tiny_ViT/compressed_prompt_getter.yml +++ b/tests/e2e/cli/visual_prompting/reference/Zero_Shot_SAM_Tiny_ViT/compressed_prompt_getter.yml @@ -1,3 +1,3 @@ 
TestToolsZeroShotVisualPrompting: ptq: - number_of_fakequantizers: 1 + number_of_fakequantizers: 0 diff --git a/tests/e2e/cli/visual_prompting/test_zero_shot.py b/tests/e2e/cli/visual_prompting/test_zero_shot.py index e5471cc7bb7..c0fb4fdb5d3 100644 --- a/tests/e2e/cli/visual_prompting/test_zero_shot.py +++ b/tests/e2e/cli/visual_prompting/test_zero_shot.py @@ -23,10 +23,10 @@ ) args = { - "--train-data-roots": "tests/assets/car_tree_bug", - "--val-data-roots": "tests/assets/car_tree_bug", - "--test-data-roots": "tests/assets/car_tree_bug", - "--input": "tests/assets/car_tree_bug/images/train", + "--train-data-roots": "tests/assets/car_tree_bug_zero_shot", + "--val-data-roots": "tests/assets/car_tree_bug_zero_shot", + "--test-data-roots": "tests/assets/car_tree_bug_zero_shot", + "--input": "tests/assets/car_tree_bug_zero_shot/images/train", "train_params": [ "params", "--learning_parameters.trainer.max_epochs", @@ -100,7 +100,7 @@ def test_otx_eval_openvino(self, template, tmp_dir_path, half_precision): tmp_dir_path, otx_dir, args, - threshold=0.2, + threshold=0.3, half_precision=half_precision, is_visual_prompting=True, ) diff --git a/tests/integration/cli/visual_prompting/test_zero_shot.py b/tests/integration/cli/visual_prompting/test_zero_shot.py index ccedf5c2fa2..33b5b433628 100644 --- a/tests/integration/cli/visual_prompting/test_zero_shot.py +++ b/tests/integration/cli/visual_prompting/test_zero_shot.py @@ -17,10 +17,10 @@ ) args = { - "--train-data-roots": "tests/assets/car_tree_bug", - "--val-data-roots": "tests/assets/car_tree_bug", - "--test-data-roots": "tests/assets/car_tree_bug", - "--input": "tests/assets/car_tree_bug/images/train", + "--train-data-roots": "tests/assets/car_tree_bug_zero_shot", + "--val-data-roots": "tests/assets/car_tree_bug_zero_shot", + "--test-data-roots": "tests/assets/car_tree_bug_zero_shot", + "--input": "tests/assets/car_tree_bug_zero_shot/images/train", "train_params": [ "params", "--learning_parameters.trainer.max_epochs", diff --git a/tests/unit/algorithms/visual_prompting/adapters/openvino/model_wrappers/test_openvino_models.py b/tests/unit/algorithms/visual_prompting/adapters/openvino/model_wrappers/test_openvino_models.py index 7740de14ab9..1b70f12aa33 100644 --- a/tests/unit/algorithms/visual_prompting/adapters/openvino/model_wrappers/test_openvino_models.py +++ b/tests/unit/algorithms/visual_prompting/adapters/openvino/model_wrappers/test_openvino_models.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: Apache-2.0 # -from typing import Tuple +from typing import Tuple, Dict, Any import numpy as np import pytest @@ -57,6 +57,26 @@ def test_parameters(self): assert params.get("sim_threshold").default_value == 0.5 assert params.get("num_bg_points").default_value == 1 + @e2e_pytest_unit + def test_get_inputs(self, mocker): + """Test _get_inputs.""" + mocker.patch.object(ImageModel, "__init__") + prompt_getter = PromptGetter("adapter") + + prompt_getter.inputs = { + "image_embeddings": np.ones((1, 4, 4, 3)), + "reference_feats": np.ones((2, 1, 256)), + "used_indices": np.array([[0, 1]], dtype=np.int64), + "original_size": np.array([[4, 4]], dtype=np.int64), + "threshold": np.array([[0.1]]), + "num_bg_points": np.array([[1]], dtype=np.int64), + } + + returned_value = prompt_getter._get_inputs() + + assert returned_value[0] == ["image_embeddings"] + assert returned_value[1] == ["reference_feats", "used_indices", "original_size", "threshold", "num_bg_points"] + class TestDecoder: @pytest.fixture(autouse=True) @@ -79,20 +99,42 @@ def 
test_get_outputs(self): """Test _get_outputs.""" results = self.decoder._get_outputs() - assert "low_res_masks" == results + assert "upscaled_masks" == results @e2e_pytest_unit - def test_preprocess(self): + @pytest.mark.parametrize( + "prompts,expected", + [ + ( + { + "bboxes": [np.array([[1, 1], [2, 2]])], + "points": [], + "labels": {"bboxes": [1]}, + "original_size": (4, 4), + }, + { + "point_coords": (1, 2, 2), + "point_labels": (1, 2), + }, + ), + ( + {"bboxes": [], "points": [np.array([[1, 1]])], "labels": {"points": [1]}, "original_size": (4, 4)}, + { + "point_coords": (1, 1, 2), + "point_labels": (1, 1), + }, + ), + ], + ) + def test_preprocess(self, prompts: Dict[str, Any], expected: Dict[str, Any]): """Test preprocess""" - prompts = {"bboxes": [np.array([[1, 1], [2, 2]])], "labels": [1], "original_size": (4, 4)} - results = self.decoder.preprocess(prompts, {}) assert isinstance(results, list) assert "point_coords" in results[0] - assert results[0]["point_coords"].shape == (1, 2, 2) + assert results[0]["point_coords"].shape == expected["point_coords"] assert "point_labels" in results[0] - assert results[0]["point_labels"].shape == (1, 2) + assert results[0]["point_labels"].shape == expected["point_labels"] assert "mask_input" in results[0] assert "has_mask_input" in results[0] assert "orig_size" in results[0] @@ -134,35 +176,11 @@ def test_get_inputs(self): @e2e_pytest_unit def test_postprocess(self, mocker): """Test postprocess.""" - self.decoder.output_blob_name = "masks" - self.decoder.soft_threshold = 0.5 + self.decoder.output_blob_name = "upscaled_masks" + self.decoder.mask_threshold = 0.0 self.decoder.blur_strength = 2 - fake_output = {"masks": np.ones((4, 4)), "iou_predictions": 0.1} + fake_output = {"upscaled_masks": np.ones((4, 4)), "scores": 0.1} fake_metadata = {"original_size": np.array([[6, 6]]), "label": mocker.Mock(spec=LabelEntity)} returned_value = self.decoder.postprocess(outputs=fake_output, meta=fake_metadata) assert isinstance(returned_value, tuple) - assert np.all(returned_value[0].shape == fake_metadata["original_size"]) - assert np.all(returned_value[1].shape == fake_metadata["original_size"]) - - @e2e_pytest_unit - def test_resize_and_crop(self, mocker): - """Test resize_and_crop.""" - mocker.patch.object(self.decoder, "get_padded_size", return_value=np.array((6, 6))) - - masks = np.zeros((2, 2)) - orig_size = np.array((8, 8)) - - results = self.decoder.resize_and_crop(masks, orig_size) - - assert results.shape == tuple(orig_size) - - @e2e_pytest_unit - def test_get_padded_size(self): - """Test get_padded_size.""" - original_size = np.array((2, 4)) - longest_side = 6 - - results = self.decoder.get_padded_size(original_size, longest_side) - - assert np.all(results == np.array((3, 6))) diff --git a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/callbacks/test_inference_callback.py b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/callbacks/test_inference_callback.py index 3cb1710d76e..8053f3dcc6c 100644 --- a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/callbacks/test_inference_callback.py +++ b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/callbacks/test_inference_callback.py @@ -78,7 +78,13 @@ def test_on_predict_epoch_end(self, use_mask: bool, expected: Any): "masks": [torch.Tensor([[[0, 1, 0], [1, 1, 1], [0, 1, 0]]])], "iou_predictions": [torch.Tensor([[0.9]])], "labels": [ - [ScoredLabel(label=LabelEntity("foreground", domain=Domain.VISUAL_PROMPTING), probability=0.0)], + { + 
"bboxes": [ + ScoredLabel( + label=LabelEntity("foreground", domain=Domain.VISUAL_PROMPTING), probability=0.0 + ) + ], + } ], } ] diff --git a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/test_sam_transforms.py b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/test_sam_transforms.py index 2636627d2b9..dca279d5175 100644 --- a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/test_sam_transforms.py +++ b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/test_sam_transforms.py @@ -20,10 +20,6 @@ class TestResizeLongestSide: def setup(self): self.resize_longest_side = ResizeLongestSide(8) - @e2e_pytest_unit - def test_call(self): - """Test __call__.""" - @e2e_pytest_unit @pytest.mark.parametrize( "image,expected", diff --git a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/test_transforms.py b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/test_transforms.py index 36225af9d4a..80890c5a155 100644 --- a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/test_transforms.py +++ b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/test_transforms.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: Apache-2.0 # -from typing import Any, Dict, Tuple +from typing import Any, Dict, Tuple, List import pytest import torch @@ -21,59 +21,102 @@ @e2e_pytest_unit -def test_collate_fn(): +@pytest.mark.parametrize( + "batch,expected", + [ + ( + [ + { + "index": 0, + "images": Tensor([1, 2, 3]), + "bboxes": Tensor([[1, 2, 3, 4], [5, 6, 7, 8]]), + "points": torch.zeros((0, 2)), + "gt_masks": [Tensor([1, 2, 3])], + "original_size": Tensor([1, 3]), + "path": [], + "labels": [], + }, + { + "index": 1, + "images": Tensor([4, 5, 6]), + "bboxes": Tensor([[9, 10, 11, 12]]), + "points": torch.zeros((0, 2)), + "gt_masks": [Tensor([4, 5, 6])], + "original_size": Tensor([1, 3]), + "path": [], + "labels": [], + }, + ], + { + "index": [0, 1], + "images": Tensor([[1, 2, 3], [4, 5, 6]]), + "bboxes": [Tensor([[1, 2, 3, 4], [5, 6, 7, 8]]), Tensor([[9, 10, 11, 12]])], + "points": [None, None], + "gt_masks": [Tensor([[1, 2, 3]]), Tensor([[4, 5, 6]])], + "original_size": [Tensor([1, 3]), Tensor([1, 3])], + "path": [[], []], + "labels": [[], []], + }, + ), + ( + [ + { + "index": 0, + "images": Tensor([1, 2, 3]), + "bboxes": torch.zeros((0, 4)), + "points": Tensor([[1, 1]]), + "gt_masks": [Tensor([1, 2, 3])], + "original_size": Tensor([1, 3]), + "path": [], + "labels": [], + }, + { + "index": 1, + "images": Tensor([4, 5, 6]), + "bboxes": torch.zeros((0, 4)), + "points": Tensor([[2, 2]]), + "gt_masks": [Tensor([4, 5, 6])], + "original_size": Tensor([1, 3]), + "path": [], + "labels": [], + }, + ], + { + "index": [0, 1], + "images": Tensor([[1, 2, 3], [4, 5, 6]]), + "bboxes": [None, None], + "points": [Tensor([[1, 1]]), Tensor([[2, 2]])], + "gt_masks": [Tensor([[1, 2, 3]]), Tensor([[4, 5, 6]])], + "original_size": [Tensor([1, 3]), Tensor([1, 3])], + "path": [[], []], + "labels": [[], []], + }, + ), + ], +) +def test_collate_fn(batch: List[Dict[str, Any]], expected: Dict[str, Any]): """Test collate_fn.""" - batch = [ - { - "index": 0, - "images": Tensor([1, 2, 3]), - "bboxes": np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), - "points": [], - "gt_masks": [Tensor([1, 2, 3])], - "original_size": np.array([1, 3]), - "padding": [], - "path": [], - "labels": [], - }, - { 
- "index": 1, - "images": Tensor([4, 5, 6]), - "bboxes": np.array([[9, 10, 11, 12]]), - "points": [], - "gt_masks": [Tensor([4, 5, 6])], - "original_size": np.array([1, 3]), - "padding": [], - "path": [], - "labels": [], - }, - ] - expected = { - "index": [0, 1], - "images": Tensor([[1, 2, 3], [4, 5, 6]]), - "bboxes": [Tensor([[1, 2, 3, 4], [5, 6, 7, 8]]), Tensor([[9, 10, 11, 12]])], - "points": None, - "gt_masks": [Tensor([[1, 2, 3]]), Tensor([[4, 5, 6]])], - "original_size": [Tensor([1, 3]), Tensor([1, 3])], - "path": [[], []], - "labels": [[], []], - "padding": [[], []], - } - results = collate_fn(batch) assert results["index"] == expected["index"] assert torch.all(results["images"] == expected["images"]) for r, e in zip(results["bboxes"], expected["bboxes"]): - assert torch.all(r == e) - assert results["points"] == expected["points"] + if r is not None and e is not None: + assert torch.all(r == e) + + for r, e in zip(results["points"], expected["points"]): + if r is not None and e is not None: + assert torch.all(r == e) + assert len(results["gt_masks"]) == len(expected["gt_masks"]) for r, e in zip(results["gt_masks"], expected["gt_masks"]): assert torch.all(r == e) + for r, e in zip(results["original_size"], expected["original_size"]): assert torch.all(r == e) + assert results["path"] == expected["path"] assert results["labels"] == expected["labels"] - assert results["padding"] == expected["padding"] class TestPad: @@ -88,22 +131,21 @@ class TestPad: bboxes=[[1, 1, 3, 3]], points=[[1, 1, 2, 2]], ), - ((0, 0, 0, 2), (3, 6, 6), [(4, 6)], [[1, 1, 3, 3]], [[1, 1, 2, 2]]), + ((3, 6, 6), [(4, 6)], [[1, 1, 3, 3]], [[1, 1, 2, 2]]), ), ( dict(images=torch.zeros((3, 4, 6)), gt_masks=[torch.zeros((4, 6))], bboxes=[[1, 1, 3, 3]], points=None), - ((0, 0, 0, 2), (3, 6, 6), [(4, 6)], [[1, 1, 3, 3]], None), + ((3, 6, 6), [(4, 6)], [[1, 1, 3, 3]], None), ), ], ) def test_call(self, item: Dict[str, Any], expected: Tuple[Any]): """Test __call__.""" pad_transform = Pad() - expected_padding, expected_images_shape, expected_gt_masks_shape, expected_bboxes, expected_points = expected + expected_images_shape, expected_gt_masks_shape, expected_bboxes, expected_points = expected result = pad_transform(item) - assert result["padding"] == expected_padding assert result["images"].shape == expected_images_shape assert len(result["gt_masks"]) == len(expected_gt_masks_shape) assert all(gt_mask.shape == shape for gt_mask, shape in zip(result["gt_masks"], expected_gt_masks_shape)) diff --git a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/test_dataset.py b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/test_dataset.py index 99a76c3b17b..6e5211c1899 100644 --- a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/test_dataset.py +++ b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/test_dataset.py @@ -19,7 +19,6 @@ generate_bbox, generate_bbox_from_mask, get_transform, - generate_point_from_mask, ) from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets.pipelines import ( MultipleInputsCompose, @@ -146,11 +145,6 @@ def test_generate_bbox_from_mask(mocker) -> None: assert bbox[3] >= 0 and bbox[3] <= height -@e2e_pytest_unit -def test_generate_point_from_mask() -> None: - """TODO""" - - class TestOTXVIsualPromptingDataset: @e2e_pytest_unit def test_len(self, mocker, dataset_polygon, transform, image_size, mean, std) -> None: @@ -159,7 +153,7 @@ def test_len(self, mocker, dataset_polygon, transform, 
image_size, mean, std) -> "otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets.dataset.get_transform", return_value=transform, ) - otx_dataset = OTXVisualPromptingDataset(dataset_polygon, image_size, mean, std) + otx_dataset = OTXVisualPromptingDataset("testing", dataset_polygon, image_size, mean, std) assert len(otx_dataset) == 4 @e2e_pytest_unit @@ -173,7 +167,7 @@ def test_getitem( return_value=transform, ) dataset = dataset_mask if use_mask else dataset_polygon - otx_dataset = OTXVisualPromptingDataset(dataset, image_size, mean, std) + otx_dataset = OTXVisualPromptingDataset("testing", dataset, image_size, mean, std) item = otx_dataset[0] @@ -189,7 +183,7 @@ def test_getitem( assert isinstance(item["gt_masks"], list) assert isinstance(item["gt_masks"][0], np.ndarray) assert isinstance(item["bboxes"], np.ndarray) - assert item["points"] == [] + assert len(item["points"]) == 0 class TestOTXZeroShotVisualPromptingDataset: @@ -209,7 +203,7 @@ def test_getitem( return_value=transform, ) dataset = dataset_mask if use_mask else dataset_polygon - otx_dataset = OTXZeroShotVisualPromptingDataset(dataset, image_size, mean, std) + otx_dataset = OTXZeroShotVisualPromptingDataset("testing", dataset, image_size, mean, std) item = otx_dataset[0] @@ -225,7 +219,7 @@ def test_getitem( assert isinstance(item["gt_masks"], list) assert isinstance(item["gt_masks"][0], np.ndarray) assert isinstance(item["bboxes"], np.ndarray) - assert item["points"] == [] + assert len(item["points"]) == 0 class TestOTXVisualPromptingDataModule: @@ -248,8 +242,8 @@ def test_init_zeroshot(self, set_datamodule): datamodule = set_datamodule(train_type=TrainType.Zeroshot) assert datamodule.config.get("train_batch_size") == 1 - assert "generate_point" in datamodule.kwargs - assert "generate_bbox" in datamodule.kwargs + assert "use_point" in datamodule.kwargs + assert "use_bbox" in datamodule.kwargs @e2e_pytest_unit def test_setup(self, mocker, set_datamodule) -> None: diff --git a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/encoders/test_sam_image_encoder.py b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/encoders/test_sam_image_encoder.py index e80137685d7..66ae6958f0b 100644 --- a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/encoders/test_sam_image_encoder.py +++ b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/encoders/test_sam_image_encoder.py @@ -22,38 +22,22 @@ def forward(self, *args, **kwargs): class TestSAMImageEncoder: - @pytest.fixture(autouse=True) - def setup(self, mocker) -> None: - self.mocker_backbone = mocker.patch( - "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.encoders.sam_image_encoder.build_vit", - return_value=MockBackbone(), - ) - - self.base_config = DictConfig(dict(backbone="vit_b", image_size=1024)) + @pytest.fixture() + def config(self, mocker) -> DictConfig: + return DictConfig(dict(image_size=1024)) @e2e_pytest_unit - @pytest.mark.parametrize("backbone", ["vit_b", "resnet"]) - def test_init(self, backbone: str): - """Test init.""" - self.mocker_backbone.reset_mock() - - config = self.base_config.copy() - config.update(dict(backbone=backbone)) - - if backbone == "resnet": - with pytest.raises(NotImplementedError): - sam_image_encoder = SAMImageEncoder(config) - else: - sam_image_encoder = SAMImageEncoder(config) - self.mocker_backbone.assert_called_once() - - @e2e_pytest_unit - def test_forward(self, mocker): - """Test forward.""" - 
self.mocker_backbone.reset_mock() - - sam_image_encoder = SAMImageEncoder(self.base_config) - mocker_forward = mocker.patch.object(sam_image_encoder.backbone, "forward") - sam_image_encoder.forward(torch.Tensor([1.0])) - - mocker_forward.assert_called_once() + @pytest.mark.parametrize( + "backbone,expected", + [ + ("tiny_vit", "TinyViT"), + ("vit_b", "ViT"), + ], + ) + def test_new(self, config: DictConfig, backbone: str, expected: str) -> None: + """Test __new__.""" + config.update({"backbone": backbone}) + + sam_image_encoder = SAMImageEncoder(config) + + assert sam_image_encoder.__class__.__name__ == expected diff --git a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_segment_anything.py b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_segment_anything.py index fed22e060c8..eecddf412a1 100644 --- a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_segment_anything.py +++ b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_segment_anything.py @@ -76,7 +76,7 @@ def test_set_models(self, mocker, backbone: str) -> None: # backbone == vit_b sam = SegmentAnything(config) - assert isinstance(sam.image_encoder, MockImageEncoder) + assert isinstance(sam.image_encoder, nn.Linear) assert isinstance(sam.prompt_encoder, MockPromptEncoder) assert isinstance(sam.mask_decoder, MockMaskDecoder) @@ -153,37 +153,21 @@ def test_set_metrics(self, mocker, loss_type: str): @e2e_pytest_unit @pytest.mark.parametrize( - "is_backbone_arg,state_dict", + "state_dict", [ - ( - False, - OrderedDict( - [ - ("image_encoder.weight", Tensor([[0.0]])), - ("image_encoder.bias", Tensor([0.0])), - ("prompt_encoder.layer.weight", Tensor([[0.0]])), - ("prompt_encoder.layer.bias", Tensor([0.0])), - ("mask_decoder.layer.weight", Tensor([[0.0]])), - ("mask_decoder.layer.bias", Tensor([0.0])), - ] - ), - ), - ( - True, - OrderedDict( - [ - ("image_encoder.backbone.weight", Tensor([[1.0]])), - ("image_encoder.backbone.bias", Tensor([1.0])), - ("prompt_encoder.layer.weight", Tensor([[1.0]])), - ("prompt_encoder.layer.bias", Tensor([1.0])), - ("mask_decoder.layer.weight", Tensor([[1.0]])), - ("mask_decoder.layer.bias", Tensor([1.0])), - ] - ), + OrderedDict( + [ + ("image_encoder.weight", torch.ones(4, 4)), + ("image_encoder.bias", torch.ones(4)), + ("prompt_encoder.layer.weight", Tensor([[1.0]])), + ("prompt_encoder.layer.bias", Tensor([1.0])), + ("mask_decoder.layer.weight", Tensor([[1.0]])), + ("mask_decoder.layer.bias", Tensor([1.0])), + ] ), ], ) - def test_load_checkpoint_with_state_dict(self, mocker, is_backbone_arg: bool, state_dict: OrderedDict): + def test_load_checkpoint_with_state_dict(self, mocker, state_dict: OrderedDict): """Test load_checkpoint with state_dict.""" mocker.patch( "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.freeze_networks" @@ -196,26 +180,8 @@ def test_load_checkpoint_with_state_dict(self, mocker, is_backbone_arg: bool, st sam_state_dict = sam.state_dict() for k, v in state_dict.items(): - if not is_backbone_arg: - k = k.replace("image_encoder", "image_encoder.backbone") assert k in sam_state_dict - assert v == sam_state_dict[k] - - @e2e_pytest_unit - def test_load_checkpoint_without_checkpoint(self, mocker): - """Test load_checkpoint without checkpoint.""" - mocker.patch( - 
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.freeze_networks" - ) - mocker.patch( - "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.set_metrics" - ) - config = self.base_config.copy() - config.model.update(dict(checkpoint=None)) - - sam = SegmentAnything(config, state_dict=None) - - assert True + assert torch.all(v == sam_state_dict[k]) @e2e_pytest_unit def test_load_checkpoint_with_url(self, mocker): @@ -230,12 +196,16 @@ def test_load_checkpoint_with_url(self, mocker): mocker_load_state_dict = mocker.patch( "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_state_dict" ) + mocker_load_from_checkpoint = mocker.patch( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_from_checkpoint" + ) config = self.base_config.copy() config.model.update(dict(checkpoint="http://checkpoint")) sam = SegmentAnything(config, state_dict=None) + mocker_load_from_checkpoint.assert_not_called() mocker_load_state_dict_from_url.assert_called_once() mocker_load_state_dict.assert_called_once() @@ -265,9 +235,35 @@ def test_load_checkpoint_from_local_checkpoint(self, mocker, monkeypatch, checkp if checkpoint.endswith(".ckpt"): mocker_load_from_checkpoint.assert_called_once() + mocker_load_state_dict.assert_not_called() else: + mocker_load_from_checkpoint.assert_not_called() mocker_load_state_dict.assert_called_once() + @e2e_pytest_unit + def test_load_checkpoint_without_checkpoint(self, mocker): + """Test load_checkpoint without checkpoint.""" + mocker.patch( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.freeze_networks" + ) + mocker.patch( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.set_metrics" + ) + mocker_load_from_checkpoint = mocker.patch( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_from_checkpoint" + ) + mocker_load_state_dict = mocker.patch( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_state_dict" + ) + mocker_load_state_dict_from_url = mocker.patch("torch.hub.load_state_dict_from_url", return_value=OrderedDict()) + + config = self.base_config.copy() + sam = SegmentAnything(config, state_dict=None) + + mocker_load_from_checkpoint.assert_not_called() + mocker_load_state_dict_from_url.assert_called_once() + mocker_load_state_dict.assert_called_once() + @e2e_pytest_unit @pytest.mark.parametrize( "point_coords,point_labels,expected", @@ -346,24 +342,14 @@ def test_select_masks(self) -> None: assert masks[:, -1, :, :] == selected_mask assert iou_preds[:, -1] == selected_iou_pred - @e2e_pytest_unit - def test_mask_postprocessing(self, mocker) -> None: - """Test mask_postprocessing.""" - masks = torch.empty(1, 1, 2, 2) - orig_size = Tensor((8, 8)) - - results = SegmentAnything.mask_postprocessing(masks, 6, orig_size) - - assert results[0, 0].shape == tuple(orig_size) - @e2e_pytest_unit def test_forward_train(self) -> None: """Test forward.""" sam = SegmentAnything(config=self.base_config) - images = torch.zeros((1)) + images = torch.zeros((1, 3, 4, 4)) bboxes = torch.zeros((1)) - results = sam.forward_train(images=images, 
bboxes=bboxes, points=None) + results = sam.forward_train(images=images, bboxes=bboxes, points=[None]) pred_masks, ious = results assert len(bboxes) == len(pred_masks) == len(ious) @@ -390,7 +376,7 @@ def test_training_step(self, mocker, loss_type: str, expected: Tensor) -> None: images=torch.ones((1, 3, 4, 4)), gt_masks=[torch.Tensor([[0, 1, 1, 0] for _ in range(4)]).to(torch.int32)], bboxes=torch.Tensor([[0, 0, 1, 1]]), - points=[], + points=[None], padding=[[0, 0, 0, 0]], original_size=[[4, 4]], ) @@ -432,7 +418,7 @@ def test_validation_step(self, mocker) -> None: images=torch.ones((1, 3, 4, 4)), gt_masks=[torch.Tensor([[0, 1, 1, 0] for _ in range(4)]).to(torch.int32)], bboxes=torch.Tensor([[0, 0, 1, 1]]), - points=[], + points=[None], path=None, labels=None, padding=[0], @@ -483,7 +469,7 @@ def test_predict_step(self, mocker, return_logits: bool, expected: Tensor) -> No batch = dict( images=torch.zeros((1, 3, 4, 4)), bboxes=torch.Tensor([[0, 0, 1, 1]]), - points=[], + points=[None], path=None, labels=None, padding=[0], @@ -496,23 +482,39 @@ def test_predict_step(self, mocker, return_logits: bool, expected: Tensor) -> No @e2e_pytest_unit @pytest.mark.parametrize( - "input_size,original_size,padding,expected", + "input_size,original_size,expected", [ - ((6, 6), (8, 8), (0, 0, 0, 0), (8, 8)), - ((6, 6), (8, 8), (0, 0, 2, 2), (8, 8)), + (6, torch.tensor((8, 8)), (1, 8, 8)), + (6, torch.tensor((8, 8)), (1, 8, 8)), ], ) - def test_postprocess_masks( - self, input_size: Tuple[int], original_size: Tuple[int], padding: Tuple[int], expected: Tuple[int] - ) -> None: + def test_postprocess_masks(self, input_size: int, original_size: Tuple[int], expected: Tuple[int]) -> None: """Test postprocess_masks.""" sam = SegmentAnything(config=self.base_config) masks = torch.zeros((1, 1, 4, 4)) - results = sam.postprocess_masks(masks, input_size, padding, original_size) + results = sam.postprocess_masks(masks, input_size, original_size) assert results.shape[1:] == expected + @e2e_pytest_unit + @pytest.mark.parametrize( + "input_image_size,expected", + [ + (torch.tensor((2, 4)), torch.tensor((3, 6))), + (torch.tensor((4, 2)), torch.tensor((6, 3))), + ], + ) + def test_get_prepadded_size(self, input_image_size: Tensor, expected: Tensor) -> None: + """Test get_prepadded_size.""" + sam = SegmentAnything(config=self.base_config) + + longest_side = 6 + + results = sam.get_prepadded_size(input_image_size, longest_side) + + assert torch.all(results == expected) + @e2e_pytest_unit @pytest.mark.parametrize( "inputs,targets,expected", diff --git a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_zero_shot_segment_anything.py b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_zero_shot_segment_anything.py index 4437fdc1f42..5f1812adf86 100644 --- a/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_zero_shot_segment_anything.py +++ b/tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_zero_shot_segment_anything.py @@ -9,6 +9,8 @@ from collections import OrderedDict from tests.test_suite.e2e_test_system import e2e_pytest_unit import torch +import numpy as np +from torch import nn from omegaconf import DictConfig from otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything import ( @@ -27,111 +29,94 @@ class TestPromptGetter: - @pytest.fixture(autouse=True) - def setup(self) -> None: - 
self.prompt_getter = PromptGetter(image_size=3, downsizing=1) - - @e2e_pytest_unit - def test_initialize(self) -> None: - """Test initialize.""" - assert not self.prompt_getter.reference_feats - assert not self.prompt_getter.reference_prompts + @pytest.fixture + def prompt_getter(self) -> PromptGetter: + return PromptGetter(image_size=4, downsizing=1) @e2e_pytest_unit - def test_set_default_thresholds(self) -> None: + def test_set_default_thresholds(self, prompt_getter) -> None: """Test set_default_thresholds.""" - assert self.prompt_getter.default_threshold_reference == 0.3 - assert self.prompt_getter.default_threshold_target == 0.65 - - self.prompt_getter.set_default_thresholds(default_threshold_reference=0.5, default_threshold_target=0.7) - - assert self.prompt_getter.default_threshold_reference == 0.5 - assert self.prompt_getter.default_threshold_target == 0.7 + assert prompt_getter.default_threshold_reference == 0.3 + assert prompt_getter.default_threshold_target == 0.65 - @e2e_pytest_unit - def test_set_reference(self) -> None: - """Test set_reference.""" - self.prompt_getter.set_reference( - label=MockScoredLabel(label=1), - reference_feats=torch.ones((self.prompt_getter.image_size, self.prompt_getter.image_size)), - reference_prompts=torch.ones((self.prompt_getter.image_size, self.prompt_getter.image_size)), - ) - - assert self.prompt_getter.reference_feats[0].sum() == 0 - assert self.prompt_getter.reference_prompts[0].sum() == 0 - assert self.prompt_getter.reference_feats[1].sum() == 9 - assert self.prompt_getter.reference_prompts[1].sum() == 9 - - self.prompt_getter.set_reference( - label=MockScoredLabel(label=3), - reference_feats=torch.ones((self.prompt_getter.image_size, self.prompt_getter.image_size)), - reference_prompts=torch.ones((self.prompt_getter.image_size, self.prompt_getter.image_size)), - ) + prompt_getter.set_default_thresholds(default_threshold_reference=0.5, default_threshold_target=0.7) - assert self.prompt_getter.reference_feats[2].sum() == 0 - assert self.prompt_getter.reference_prompts[2].sum() == 0 - assert self.prompt_getter.reference_feats[3].sum() == 9 - assert self.prompt_getter.reference_prompts[3].sum() == 9 + assert prompt_getter.default_threshold_reference == 0.5 + assert prompt_getter.default_threshold_target == 0.7 @e2e_pytest_unit - def test_forward(self, mocker) -> None: + @pytest.mark.parametrize( + "result_point_selection", + [torch.tensor([[2, 2, 0.9], [1, 2, 0.8], [0, 2, 0.7], [2, 1, 0.6]]), torch.tensor([[-1, -1, -1]])], + ) + def test_forward(self, mocker, prompt_getter, result_point_selection: torch.Tensor) -> None: """Test forward.""" - mocker.patch.object( - self.prompt_getter, - "get_prompt_candidates", - return_value=(torch.tensor([[[0, 0, 0.5], [1, 1, 0.7]]]), torch.tensor([[[2, 2]]])), + mocker.patch( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything.ZeroShotSegmentAnything" ) + mocker.patch.object(prompt_getter, "_point_selection", return_value=(result_point_selection, torch.zeros(1, 2))) image_embeddings = torch.ones(1, 4, 4, 4) - self.prompt_getter.reference_feats = torch.rand(1, 1, 4) - original_size = torch.tensor((self.prompt_getter.image_size, self.prompt_getter.image_size), dtype=torch.int64) + reference_feat = torch.rand(1, 4) + original_size = torch.tensor([[prompt_getter.image_size, prompt_getter.image_size]], dtype=torch.int64) - total_points_scores, total_bg_coords = self.prompt_getter( - image_embeddings=image_embeddings, original_size=original_size + 
points_scores, bg_coords = prompt_getter( + image_embeddings=image_embeddings, reference_feat=reference_feat, original_size=original_size ) - assert total_points_scores.shape[0] == 1 - assert total_bg_coords.shape[0] == 1 + assert torch.all(points_scores == result_point_selection) + assert torch.all(bg_coords == torch.zeros(1, 2)) @e2e_pytest_unit - def test_get_prompt_candidates(self, mocker) -> None: + @pytest.mark.parametrize( + "result_point_selection", + [torch.tensor([[2, 2, 0.9], [1, 2, 0.8], [0, 2, 0.7], [2, 1, 0.6]]), torch.tensor([[-1, -1, -1]])], + ) + def test_get_prompt_candidates(self, mocker, prompt_getter, result_point_selection: torch.Tensor) -> None: """Test get_prompt_candidates.""" - mocker.patch( - "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything.ZeroShotSegmentAnything" - ) - mocker.patch.object(self.prompt_getter, "_point_selection", return_value=("points_scores", "bg_coords")) + mocker.patch.object(prompt_getter, "_point_selection", return_value=(result_point_selection, torch.zeros(1, 2))) image_embeddings = torch.ones(1, 4, 4, 4) - self.prompt_getter.reference_feats = torch.rand(1, 1, 4) - label = torch.tensor([[0]], dtype=torch.int64) - original_size = torch.tensor( - [[self.prompt_getter.image_size, self.prompt_getter.image_size]], dtype=torch.int64 - ) - - points_scores, bg_coords = self.prompt_getter.get_prompt_candidates( - image_embeddings=image_embeddings, label=label, original_size=original_size + reference_feats = torch.rand(1, 1, 4) + used_indices = torch.as_tensor([[0]]) + original_size = torch.tensor((prompt_getter.image_size, prompt_getter.image_size), dtype=torch.int64) + + total_points_scores, total_bg_coords = prompt_getter.get_prompt_candidates( + image_embeddings=image_embeddings, + reference_feats=reference_feats, + used_indices=used_indices, + original_size=original_size, ) - assert points_scores == "points_scores" - assert bg_coords == "bg_coords" + assert total_points_scores[0].shape[0] == len(result_point_selection) + assert total_bg_coords[0].shape[0] == 1 @e2e_pytest_unit - def test_point_selection(self) -> None: + @pytest.mark.parametrize( + "mask_sim,expected", + [ + ( + torch.arange(0.1, 1.0, 0.1).reshape(3, 3), + torch.tensor([[2, 2, 0.9], [1, 2, 0.8], [0, 2, 0.7], [2, 1, 0.6]]), + ), + (torch.zeros(3, 3), torch.tensor([[-1, -1, -1]])), + ], + ) + def test_point_selection(self, prompt_getter, mask_sim: torch.Tensor, expected: torch.Tensor) -> None: """Test _point_selection.""" - mask_sim = torch.arange(0.1, 1.0, 0.1).reshape(self.prompt_getter.image_size, self.prompt_getter.image_size) - - points_scores, bg_coords = self.prompt_getter._point_selection( + points_scores, bg_coords = prompt_getter._point_selection( mask_sim=mask_sim, - original_size=torch.tensor([self.prompt_getter.image_size, self.prompt_getter.image_size]), + original_size=torch.tensor([prompt_getter.image_size, prompt_getter.image_size]), threshold=torch.tensor([[0.5]]), ) - assert torch.equal(points_scores, torch.tensor([[2, 2, 0.9], [1, 2, 0.8], [0, 2, 0.7], [2, 1, 0.6]])) - assert torch.equal(bg_coords, torch.tensor([[0, 0]])) + assert torch.equal(points_scores, expected) class TestZeroShotSegmentAnything: @pytest.fixture def set_zero_shot_segment_anything(self, monkeypatch): - def zero_shot_segment_anything(state_dict: Optional[OrderedDict] = None): + def zero_shot_segment_anything( + manual_config_update: Optional[Dict] = None, state_dict: Optional[OrderedDict] = None + ): monkeypatch.setattr( 
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SAMImageEncoder", MockImageEncoder, @@ -140,7 +125,7 @@ def zero_shot_segment_anything(state_dict: Optional[OrderedDict] = None): "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SAMMaskDecoder", MockMaskDecoder, ) - return ZeroShotSegmentAnything(state_dict=state_dict) + return ZeroShotSegmentAnything(manual_config_update=manual_config_update, state_dict=state_dict) return zero_shot_segment_anything @@ -149,14 +134,16 @@ def zero_shot_segment_anything(state_dict: Optional[OrderedDict] = None): "state_dict", [ None, - { - "prompt_getter.reference_feats": "prompt_getter.reference_feats", - "prompt_getter.reference_prompts": "prompt_getter.reference_prompts", - }, + {}, ], ) - def test_init(self, set_zero_shot_segment_anything, state_dict: Dict[str, Any]) -> None: + def test_init(self, set_zero_shot_segment_anything, state_dict: Optional[Dict[str, Any]]) -> None: """Test __init__.""" + if state_dict is not None: + state_dict = set_zero_shot_segment_anything().state_dict() + state_dict.pop("reference_info.reference_feats") + state_dict.pop("reference_info.used_indices") + zero_shot_segment_anything = set_zero_shot_segment_anything(state_dict=state_dict) assert zero_shot_segment_anything.config.model.freeze_image_encoder @@ -164,8 +151,11 @@ def test_init(self, set_zero_shot_segment_anything, state_dict: Dict[str, Any]) assert zero_shot_segment_anything.config.model.freeze_mask_decoder if state_dict: - zero_shot_segment_anything.prompt_getter.reference_feats = "prompt_getter.reference_feats" - zero_shot_segment_anything.prompt_getter.reference_prompts = "prompt_getter.reference_prompts" + assert zero_shot_segment_anything.reference_info.reference_feats is not None + assert zero_shot_segment_anything.reference_info.used_indices is not None + + assert zero_shot_segment_anything.reference_info.reference_feats.dtype == torch.float32 + assert zero_shot_segment_anything.reference_info.used_indices.dtype == torch.int64 @e2e_pytest_unit def test_set_default_config(self, set_zero_shot_segment_anything) -> None: @@ -189,26 +179,31 @@ def test_set_default_config(self, set_zero_shot_segment_anything) -> None: @e2e_pytest_unit def test_learn(self, mocker, set_zero_shot_segment_anything) -> None: """Test learn.""" - zero_shot_segment_anything = set_zero_shot_segment_anything() + zero_shot_segment_anything = set_zero_shot_segment_anything(manual_config_update={"model.image_size": 4}) mocker.patch.object( zero_shot_segment_anything, "_predict_masks", - return_value=torch.tensor([[[[0, 0, 0], [0, 1, 0], [0, 0, 0]]]]), - ) - - processed_prompts = {MockScoredLabel(label=1, name="label"): [{"box": torch.tensor([[0, 0, 1, 1]])}]} - zero_shot_segment_anything.learn( - images=torch.ones((1, 3, 8, 8)), - processed_prompts=processed_prompts, - padding=(0, 0, 0, 0), - original_size=(8, 8), + return_value=torch.tensor([[[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]]]), ) + mocker.patch.object(zero_shot_segment_anything, "_generate_masked_features", return_value=torch.ones(1, 256)) - assert zero_shot_segment_anything.prompt_getter.reference_feats.shape == (2, 1, 2) - assert zero_shot_segment_anything.prompt_getter.reference_prompts.shape == (2, 8, 8) + batch = [ + { + "images": np.ones((4, 4, 3), dtype=np.uint8), + "gt_masks": np.ones((4, 4), dtype=np.uint8), + "bboxes": np.array([[0, 0, 1, 1]], dtype=np.float32), + "points": np.zeros((0, 2), 
dtype=np.float32), + "labels": {"bboxes": [MockScoredLabel(label=0, name="label")]}, + "original_size": np.array([4, 4], dtype=np.int64), + } + ] + zero_shot_segment_anything.learn(batch=batch, reset_feat=True) + + assert zero_shot_segment_anything.reference_info.reference_feats.shape == (1, 1, 256) + assert zero_shot_segment_anything.reference_info.used_indices == torch.as_tensor([0]) @e2e_pytest_unit - @pytest.mark.parametrize("expected", [[torch.ones((8, 8)) / 2, torch.tensor([0.0, 0.0, 0.5])]]) + @pytest.mark.parametrize("expected", [[torch.ones((4, 4)) / 2, torch.tensor([0.0, 0.0, 0.5])]]) def test_infer(self, monkeypatch, mocker, set_zero_shot_segment_anything, expected: torch.Tensor) -> None: """Test infer.""" monkeypatch.setattr( @@ -216,15 +211,26 @@ def test_infer(self, monkeypatch, mocker, set_zero_shot_segment_anything, expect MockPromptGetter, ) - zero_shot_segment_anything = set_zero_shot_segment_anything() - zero_shot_segment_anything.prompt_getter.reference_feats = torch.rand(1, 1, 4) - zero_shot_segment_anything.prompt_getter.reference_prompts = torch.zeros((8, 8)) + zero_shot_segment_anything = set_zero_shot_segment_anything(manual_config_update={"model.image_size": 4}) + reference_feats = nn.Parameter(torch.rand(1, 1, 256), requires_grad=False) + used_indices = nn.Parameter(torch.as_tensor([[0]], dtype=torch.int64), requires_grad=False) mocker.patch.object( - SegmentAnything, "forward", return_value=(torch.tensor([[0.1, 0.2, 0.5, 0.7]]), torch.ones(1, 4, 4, 4)) + SegmentAnything, + "forward", + return_value=(torch.ones(1, 4, 4, 4), torch.tensor([[0.1, 0.2, 0.5, 0.7]]), torch.ones(1, 4, 4, 4)), ) + batch = [ + { + "images": np.ones((4, 4, 3), dtype=np.uint8), + "gt_masks": np.ones((4, 4), dtype=np.uint8), + "original_size": np.array([4, 4], dtype=np.int64), + } + ] total_results = zero_shot_segment_anything.infer( - images=torch.ones((1, 3, 8, 8)), original_size=torch.tensor([[8, 8]], dtype=torch.int64) + batch=batch, + reference_feats=reference_feats, + used_indices=used_indices, ) for i, results in enumerate(total_results[0]): @@ -232,11 +238,116 @@ def test_infer(self, monkeypatch, mocker, set_zero_shot_segment_anything, expect assert torch.equal(result[0], expected[i]) @e2e_pytest_unit - @pytest.mark.parametrize("is_postprocess", [True, False]) - def test_predict_masks(self, mocker, set_zero_shot_segment_anything, is_postprocess: bool) -> None: + def test_inspect_overlapping_areas(self, mocker, set_zero_shot_segment_anything) -> None: + """Test _inspect_overlapping_areas.""" + mocker.patch( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_checkpoint" + ) + zero_shot_segment_anything = set_zero_shot_segment_anything() + predicted_masks = { + 0: [ + torch.tensor( + [ + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + ), + torch.tensor( + [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + ), + torch.tensor( + [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0], + [0, 1, 1, 1, 0, 0], + ], + ), + ], + 1: [ + torch.tensor( + [ + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + ), + torch.tensor( + [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 
0], + [0, 0, 0, 0, 1, 1], + [0, 0, 0, 0, 1, 1], + ], + ), + torch.tensor( + [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 1, 1, 0, 0, 0], + [0, 1, 1, 0, 0, 0], + ], + ), + torch.tensor( + [ + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + ), + ], + } + used_points = { + 0: [ + torch.tensor([0, 0, 0.5]), # to be removed + torch.tensor([2, 2, 0.5]), + torch.tensor([1, 4, 0.5]), + ], + 1: [ + torch.tensor([3, 0, 0.5]), + torch.tensor([4, 4, 0.5]), + torch.tensor([1, 4, 0.3]), # to be removed + torch.tensor([0, 0, 0.7]), + ], + } + + zero_shot_segment_anything._inspect_overlapping_areas(predicted_masks, used_points, threshold_iou=0.5) + + assert len(predicted_masks[0]) == 2 + assert len(predicted_masks[1]) == 3 + assert all(torch.tensor([2, 2, 0.5]) == used_points[0][0]) + assert all(torch.tensor([0, 0, 0.7]) == used_points[1][2]) + + @e2e_pytest_unit + def test_predict_masks(self, mocker, set_zero_shot_segment_anything) -> None: """Test _predict_masks.""" mocker.patch.object( - SegmentAnything, "forward", return_value=(torch.tensor([[0.1, 0.2, 0.5, 0.7]]), torch.ones(1, 4, 4, 4)) + SegmentAnything, + "forward", + return_value=(torch.ones(1, 4, 8, 8), torch.tensor([[0.1, 0.2, 0.5, 0.7]]), torch.ones(1, 4, 4, 4)), ) zero_shot_segment_anything = set_zero_shot_segment_anything() @@ -246,28 +357,27 @@ def test_predict_masks(self, mocker, set_zero_shot_segment_anything, is_postproc image_embeddings=torch.rand(1), point_coords=torch.rand(1, 2, 2), point_labels=torch.randint(low=0, high=2, size=(1, 2)), - original_size=torch.tensor((8, 8), dtype=torch.int64), + original_size=torch.tensor([8, 8], dtype=torch.int64), ) assert mask.shape == (8, 8) @e2e_pytest_unit def test_preprocess_prompts(self, set_zero_shot_segment_anything) -> None: - """Test _preprocess_prompts. - - TODO (sungchul) - - get inputs grouped as label and prompts - - use points and annotations. 
- """ + """Test _preprocess_prompts.""" zero_shot_segment_anything = set_zero_shot_segment_anything() - bboxes = [torch.tensor([0, 0, 1, 1])] - labels = [MockScoredLabel(label=1)] - processed_prompts = zero_shot_segment_anything._preprocess_prompts( - bboxes=bboxes, - labels=labels, - ) + transformed_batch = { + "bboxes": torch.tensor([[0, 0, 1, 1]]), + "points": torch.tensor([[2, 2]]), + "labels": {"bboxes": [MockScoredLabel(label=1)], "points": [MockScoredLabel(label=1)]}, + } + processed_prompts = zero_shot_segment_anything._preprocess_prompts(transformed_batch) - # processed_prompts = {labels[0]: [{"box": torch.tensor([[0, 0, 1, 1]])}]} - assert torch.equal(processed_prompts[labels[0]][0].get("box")[0], bboxes[0]) + for prompts in processed_prompts.values(): + for prompt in prompts: + if "bboxes" in prompt: + prompt["bboxes"]["point_coords"].shape == (1, 2, 2) + elif "points" in prompt: + prompt["points"]["point_coords"].shape == (1, 1, 2) @e2e_pytest_unit def test_generate_masked_features(self, set_zero_shot_segment_anything) -> None: @@ -283,12 +393,12 @@ def test_generate_masked_features(self, set_zero_shot_segment_anything) -> None: assert masked_feat.shape == (1, 1) @e2e_pytest_unit - def test_preprocess_masks(self, set_zero_shot_segment_anything) -> None: - """Test _preprocess_masks.""" + def test_pad_to_square(self, set_zero_shot_segment_anything) -> None: + """Test _pad_to_square.""" zero_shot_segment_anything = set_zero_shot_segment_anything() zero_shot_segment_anything.config.model.image_size = 16 - result = zero_shot_segment_anything._preprocess_masks(x=torch.ones(1, 1, 8, 8)) + result = zero_shot_segment_anything._pad_to_square(x=torch.ones(1, 1, 8, 8)) assert result[:8, :8].sum() == 8**2 assert result[:8, 8:].sum() == 0 @@ -297,47 +407,97 @@ def test_preprocess_masks(self, set_zero_shot_segment_anything) -> None: @e2e_pytest_unit @pytest.mark.parametrize( - "logits,expected", + "masks,logits,expected", [ - (torch.ones(1, 4, 4, 4), torch.ones(4, 4, dtype=torch.bool)), - (torch.zeros(1, 4, 4, 4), torch.zeros(4, 4, dtype=torch.bool)), + (torch.ones(1, 4, 8, 8), torch.ones(1, 4, 4, 4), torch.ones(8, 8)), + (torch.zeros(1, 4, 8, 8), torch.zeros(1, 4, 4, 4), torch.zeros(8, 8)), ], ) def test_postprocess_masks( - self, set_zero_shot_segment_anything, logits: torch.Tensor, expected: torch.Tensor + self, set_zero_shot_segment_anything, masks: torch.Tensor, logits: torch.Tensor, expected: torch.Tensor ) -> None: """Test _postprocess_masks.""" zero_shot_segment_anything = set_zero_shot_segment_anything() zero_shot_segment_anything.config.model.image_size = 4 scores = torch.tensor([[0.0, 0.1, 0.2, 0.3]]) - original_size = torch.tensor([4, 4], dtype=torch.int64) - _, result = zero_shot_segment_anything._postprocess_masks(logits, scores, original_size) + _, result = zero_shot_segment_anything._postprocess_masks(masks, logits, scores) assert torch.equal(result, expected) @e2e_pytest_unit - @pytest.mark.parametrize("use_only_background", [True, False]) - def test_merge_prompts(self, set_zero_shot_segment_anything, use_only_background: bool) -> None: - """Test _merge_prompts.""" + def test_find_latest_reference_info(self, mocker, set_zero_shot_segment_anything): + """Test _find_latest_reference_info.""" zero_shot_segment_anything = set_zero_shot_segment_anything() + mocker.patch( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything.os.path.isdir", + return_value=True, + ) - input_prompts = {"point_coords": torch.tensor([1]), 
"point_labels": torch.tensor([1])} - processed_prompts = { - MockScoredLabel(label=0): [{"point_coords": torch.tensor([0]), "point_labels": torch.tensor([0])}], - MockScoredLabel(label=2): [{"point_coords": torch.tensor([2]), "point_labels": torch.tensor([1])}], - } + # there are some saved reference info + mocker.patch( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything.os.listdir", + return_value=["1", "2"], + ) + results = zero_shot_segment_anything._find_latest_reference_info() + assert results == "2" - merged_input_prompts = zero_shot_segment_anything._merge_prompts( - label=MockScoredLabel(label=1), - input_prompts=input_prompts, - processed_prompts=processed_prompts, - use_only_background=use_only_background, + # there are no saved reference info + mocker.patch( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything.os.listdir", + return_value=[], ) + results = zero_shot_segment_anything._find_latest_reference_info() + assert results is None + + @e2e_pytest_unit + def test_on_predict_start(self, mocker, set_zero_shot_segment_anything): + """Test on_predict_start.""" + zero_shot_segment_anything = set_zero_shot_segment_anything() + mocker.patch( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything.os.path.isdir", + return_value=True, + ) + + # get previously saved reference info + mocker.patch( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything.os.listdir", + return_value=["1", "2"], + ) + mocker.patch( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything.torch.load", + return_value=torch.nn.ParameterDict( + {"reference_feats": torch.zeros((1, 1, 256)), "used_indices": torch.tensor([[0.0]])} + ), + ) + mocker.patch("builtins.open", return_value="Mocked data") + + zero_shot_segment_anything.on_predict_start() + assert isinstance(zero_shot_segment_anything.reference_info, torch.nn.ParameterDict) + assert zero_shot_segment_anything.reference_info["reference_feats"].shape == (1, 1, 256) + assert zero_shot_segment_anything.reference_info["used_indices"].shape == (1, 1) + + # no saved reference info + mocker.patch( + "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.zero_shot_segment_anything.os.listdir", + return_value=[], + ) + + zero_shot_segment_anything.set_empty_reference_info() + zero_shot_segment_anything.on_predict_start() + + assert zero_shot_segment_anything.reference_info["reference_feats"].shape == (0,) + assert zero_shot_segment_anything.reference_info["used_indices"].shape == (1, 0) + + @e2e_pytest_unit + def test_expand_reference_info(self, set_zero_shot_segment_anything): + """Test expand_reference_info.""" + zero_shot_segment_anything = set_zero_shot_segment_anything() + zero_shot_segment_anything.reference_info["reference_feats"] = torch.ones((3, 2, 2)) + new_largest_label = 5 + + zero_shot_segment_anything.expand_reference_info(new_largest_label) - if use_only_background: - assert torch.equal(merged_input_prompts.get("point_coords"), torch.tensor([1, 0])) - assert torch.equal(merged_input_prompts.get("point_labels"), torch.tensor([1, 0])) - else: - assert torch.equal(merged_input_prompts.get("point_coords"), torch.tensor([1, 0, 2])) - assert torch.equal(merged_input_prompts.get("point_labels"), torch.tensor([1, 0, 0])) + assert 
zero_shot_segment_anything.reference_info["reference_feats"].shape == (6, 2, 2) + assert torch.all(zero_shot_segment_anything.reference_info["reference_feats"][:3] == 1.0) + assert torch.all(zero_shot_segment_anything.reference_info["reference_feats"][3:] == 0.0) diff --git a/tests/unit/algorithms/visual_prompting/tasks/test_inference.py b/tests/unit/algorithms/visual_prompting/tasks/test_inference.py index acd9d0c48ca..8b330131b0f 100644 --- a/tests/unit/algorithms/visual_prompting/tasks/test_inference.py +++ b/tests/unit/algorithms/visual_prompting/tasks/test_inference.py @@ -343,6 +343,7 @@ def test_export_to_onnx(self): }, "visual_prompting_prompt_getter": { "image_embeddings": np.random.randn(1, embed_dim, *embed_size).astype(dtype=np.float32), + "reference_feat": np.random.randn(1, 256).astype(dtype=np.float32), "original_size": np.random.randint(low=0, high=image_size * 2, size=(1, 2), dtype=np.int64), "threshold": np.array([[0.1]], dtype=np.float32), "num_bg_points": np.random.randint(low=1, high=image_size, size=(1, 1), dtype=np.int64), @@ -353,12 +354,13 @@ def test_export_to_onnx(self): "point_labels": np.random.randint(low=0, high=4, size=(1, 2)).astype(np.float32), "mask_input": np.random.randn(1, 1, *mask_input_size).astype(np.float32), "has_mask_input": np.array([[1]], dtype=np.float32), + "orig_size": np.random.randint(low=256, high=2048, size=(1, 2)).astype(np.int64), }, } onnx_outputs = { "visual_prompting_image_encoder": ["image_embeddings"], - "visual_prompting_prompt_getter": ["total_points_scores", "total_bg_coords"], - "visual_prompting_decoder": ["iou_predictions", "low_res_masks"], + "visual_prompting_prompt_getter": ["points_scores", "bg_coords"], + "visual_prompting_decoder": ["upscaled_masks", "iou_predictions", "low_res_masks"], } onnx_rt_models = { @@ -378,10 +380,13 @@ def test_save_model(self, mocker): mocker_otx_model = mocker.patch("otx.api.entities.model.ModelEntity") mocker_io_bytes_io = mocker.patch("io.BytesIO") mocker_torch_save = mocker.patch("torch.save") + mocker.patch.object( + self.zero_shot_task.model, + "state_dict", + return_value={"reference_info.reference_feats": None, "reference_info.used_indices": None}, + ) - self.zero_shot_task.model.prompt_getter = mocker.MagicMock() - self.zero_shot_task.model.prompt_getter.reference_feats.return_value = "reference_feats" - self.zero_shot_task.model.prompt_getter.reference_prompts.return_value = "reference_prompts" + self.zero_shot_task.model.reference_info = "reference_info" self.zero_shot_task.save_model(mocker_otx_model) diff --git a/tests/unit/algorithms/visual_prompting/tasks/test_openvino.py b/tests/unit/algorithms/visual_prompting/tasks/test_openvino.py index 8711eb1705a..8dab6141cab 100644 --- a/tests/unit/algorithms/visual_prompting/tasks/test_openvino.py +++ b/tests/unit/algorithms/visual_prompting/tasks/test_openvino.py @@ -49,6 +49,7 @@ generate_visual_prompting_dataset, init_environment, ) +from tests.unit.algorithms.visual_prompting.test_helpers import MockScoredLabel class TestOpenVINOVisualPromptingInferencer: @@ -138,7 +139,7 @@ def test_predict(self, mocker): return_value={"image_embeddings": np.empty((4, 2, 2))}, ) mocker_forward_decoder = mocker.patch.object( - OpenVINOVisualPromptingInferencer, "forward_decoder", return_value=None + OpenVINOVisualPromptingInferencer, "forward_decoder", return_value={"iou_predictions": 0.1} ) mocker_post_process = mocker.patch.object( OpenVINOVisualPromptingInferencer, "post_process", return_value=(self.fake_annotation, None, None) @@ -193,23 
+194,74 @@ def setup(self, mocker):
         visual_prompting_hparams = self.task_environment.get_hyper_parameters(VisualPromptingBaseConfig)
         label_schema = self.task_environment.label_schema
-        self.visual_prompting_ov_inferencer = OpenVINOZeroShotVisualPromptingInferencer(
+        self.zero_shot_visual_prompting_ov_inferencer = OpenVINOZeroShotVisualPromptingInferencer(
             visual_prompting_hparams,
             label_schema,
             {"image_encoder": "", "prompt_getter": "", "decoder": ""},
             {"image_encoder": "", "prompt_getter": "", "decoder": ""},
         )
-        self.visual_prompting_ov_inferencer.model["decoder"] = mocker.patch(
-            "otx.algorithms.visual_prompting.tasks.openvino.model_wrappers.Decoder", autospec=True
+        self.zero_shot_visual_prompting_ov_inferencer.model["decoder"] = mocker.patch(
+            "otx.algorithms.visual_prompting.tasks.openvino.model_wrappers.Decoder",
+            autospec=True,
         )
-        self.visual_prompting_ov_inferencer.model["decoder"]._apply_coords.return_value = np.array([[1, 1]])
+        self.zero_shot_visual_prompting_ov_inferencer.model["decoder"].mask_threshold = 0.3
+        self.zero_shot_visual_prompting_ov_inferencer.model["decoder"]._apply_coords.return_value = np.array([[1, 1]])
+        self.zero_shot_visual_prompting_ov_inferencer.model["decoder"].output_blob_name = "upscaled_masks"
+
+    @e2e_pytest_unit
+    def test_learn(self, mocker):
+        """Test learn."""
+        mocker_pre_process = mocker.patch.object(
+            OpenVINOVisualPromptingInferencer,
+            "pre_process",
+            return_value=(
+                torch.zeros((1, 3, 2, 2)),
+                {"original_shape": np.array((4, 4))},
+                [
+                    {
+                        "point_coords": [np.array([[[1, 1], [2, 2]]])],
+                        "point_labels": [1, 2],
+                        "label": MockScoredLabel(label=0, name="fake"),
+                        "orig_size": (4, 4),
+                    }
+                ],
+            ),
+        )
+        mocker_forward_image_encoder = mocker.patch.object(
+            OpenVINOZeroShotVisualPromptingInferencer,
+            "forward_image_encoder",
+            return_value={"image_embeddings": np.empty((4, 2, 2))},
+        )
+        mocker_generate_masked_features = mocker.patch.object(
+            OpenVINOZeroShotVisualPromptingInferencer, "_generate_masked_features", return_value=torch.ones(1, 256)
+        )
+
+        self.zero_shot_visual_prompting_ov_inferencer.model["decoder"].infer_sync.return_value = {
+            "upscaled_masks": np.ones((1, 4, 4, 4), dtype=bool),
+            "iou_predictions": np.array([[0.9, 0.7, 0.9, 0.8]]),
+            "low_res_masks": np.ones((1, 4, 2, 2)),
+        }
+        mocker_pickle_dump = mocker.patch("otx.algorithms.visual_prompting.tasks.openvino.pickle.dump")
+        mocker.patch("builtins.open", return_value="Mocked data")
+        self.zero_shot_visual_prompting_ov_inferencer.model["prompt_getter"].default_threshold_reference = 0.3
+
+        fake_input = mocker.Mock(spec=DatasetItemEntity)
+        results = self.zero_shot_visual_prompting_ov_inferencer.learn(fake_input, reset_feat=True)
+
+        assert results[0]["reference_feats"].shape == (1, 1, 256)
+        assert results[0]["used_indices"] == np.array([[0]])
+        assert np.all(results[1] == np.ones((1, 4, 4)))
+        mocker_pre_process.assert_called_once()
+        mocker_forward_image_encoder.assert_called_once()
+        mocker_generate_masked_features.assert_called_once()
+        mocker_pickle_dump.assert_called_once()

     @e2e_pytest_unit
     def test_predict(self, mocker):
         """Test predict."""
         mocker_pre_process = mocker.patch.object(
             OpenVINOZeroShotVisualPromptingInferencer,
-            "pre_process",
+            "pre_process_image_encoder",
             return_value=(torch.zeros((1, 3, 2, 2)), {"original_shape": (4, 4, 1)}),
         )
         mocker_forward = mocker.patch.object(
@@ -220,37 +272,65 @@ def test_predict(self, mocker):
         mocker_forward_decoder = mocker.patch.object(
             OpenVINOZeroShotVisualPromptingInferencer,
             "forward_prompt_getter",
-
return_value={"total_points_scores": np.array([[[1, 1, 1]]]), "total_bg_coords": np.array([[[2, 2]]])}, + return_value=({0: np.array([[1, 1, 1]])}, {0: np.array([[2, 2]])}), ) mocker_forward_decoder = mocker.patch.object( - OpenVINOZeroShotVisualPromptingInferencer, "forward_decoder", return_value=None + OpenVINOZeroShotVisualPromptingInferencer, "forward_decoder", return_value={"upscaled_masks": None} ) mocker_post_process = mocker.patch.object( OpenVINOZeroShotVisualPromptingInferencer, "post_process", return_value=(self.fake_annotation, None, None) ) + self.zero_shot_visual_prompting_ov_inferencer.reference_feats = np.random.rand(1, 1, 1) + self.zero_shot_visual_prompting_ov_inferencer.used_indices = np.array([[0]]) fake_input = mocker.Mock(spec=DatasetItemEntity) - returned_value = self.visual_prompting_ov_inferencer.predict(fake_input) + results = self.zero_shot_visual_prompting_ov_inferencer.predict(fake_input) mocker_pre_process.assert_called_once() mocker_forward.assert_called_once() mocker_forward_decoder.assert_called_once() mocker_post_process.assert_called_once() - assert returned_value == self.fake_annotation + assert results == self.fake_annotation + + @e2e_pytest_unit + def test_forward_prompt_getter(self): + """Test forward_prompt_getter.""" + self.zero_shot_visual_prompting_ov_inferencer.model["prompt_getter"].infer_sync.return_value = { + "points_scores": np.array([[1, 1, 0.5]]), + "bg_coords": np.array([[0, 0]]), + } + + total_points_scores, total_bg_coords = self.zero_shot_visual_prompting_ov_inferencer.forward_prompt_getter( + image_embeddings={"image_embeddings": np.empty((4, 2, 2))}, + reference_feats=np.random.rand(1, 1, 1), + used_indices=np.array([[0]]), + original_size=np.array([4, 4]), + ) + + assert np.all(total_points_scores[0] == np.array([[1, 1, 0.5]])) + assert np.all(total_bg_coords[0] == np.array([[0, 0]])) @e2e_pytest_unit @pytest.mark.parametrize( "postprocess_output,infer_sync_output,expected", [ ( - (np.ones((1, 1)), np.ones((3, 3)), 0.9), - {"iou_predictions": np.array([[0.9]]), "low_res_masks": np.ones((1, 1, 2, 2))}, - {"iou_predictions": np.array([[0.9]]), "low_res_masks": np.ones((1, 1, 2, 2))}, + (np.ones((1, 1)), np.ones((3, 3))), + { + "upscaled_masks": np.ones((3, 3)), + "iou_predictions": np.array([[0.9]]), + "low_res_masks": np.ones((1, 1, 2, 2)), + }, + {"upscaled_masks": np.ones((3, 3))}, ), ( - (np.zeros((2, 2)), np.zeros((3, 3)), 0.0), - {"iou_predictions": np.array([[0.9]]), "low_res_masks": np.ones((1, 1, 2, 2))}, - {"iou_predictions": 0.0, "low_res_masks": np.zeros((2, 2))}, + (np.zeros((2, 2)), np.zeros((3, 3))), + { + "upscaled_masks": np.zeros((3, 3)), + "iou_predictions": np.array([[0.9]]), + "low_res_masks": np.ones((1, 1, 2, 2)), + }, + {"upscaled_masks": np.zeros((3, 3))}, ), ], ) @@ -263,16 +343,18 @@ def test_forward_decoder( ): """Test forward_decoder.""" mocker.patch.object( - self.visual_prompting_ov_inferencer.model["decoder"], "infer_sync", return_value=infer_sync_output + self.zero_shot_visual_prompting_ov_inferencer.model["decoder"], "infer_sync", return_value=infer_sync_output ) mocker.patch.object( - self.visual_prompting_ov_inferencer.model["decoder"], + self.zero_shot_visual_prompting_ov_inferencer.model["decoder"], "_apply_coords", return_value=np.array([[[1, 1]]], dtype=np.float32), ) - mocker.patch.object(self.visual_prompting_ov_inferencer, "_postprocess_masks", return_value=postprocess_output) + mocker.patch.object( + self.zero_shot_visual_prompting_ov_inferencer, "_postprocess_masks", 
return_value=postprocess_output + ) - result = self.visual_prompting_ov_inferencer.forward_decoder( + result = self.zero_shot_visual_prompting_ov_inferencer.forward_decoder( inputs={ "image_embeddings": np.empty((1, 4, 2, 2)), "point_coords": np.array([[[1, 1]]], dtype=np.float32), @@ -281,45 +363,217 @@ def test_forward_decoder( original_size=np.array([3, 3]), ) - assert np.all(result["iou_predictions"] == expected["iou_predictions"]) - assert np.all(result["low_res_masks"] == expected["low_res_masks"]) + assert np.all(result["upscaled_masks"] == expected["upscaled_masks"]) @e2e_pytest_unit @pytest.mark.parametrize( - "high_res_masks,expected_masks,expected_scores", + "masks,expected_masks", [ ( - np.repeat(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])[..., None], 4, axis=-1), + np.repeat(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])[None], 4, axis=0)[None], np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.bool_), - 0.9, ), ( np.concatenate( ( - np.repeat(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])[..., None], 3, axis=-1), - np.zeros((3, 3, 1)), + np.repeat(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])[None], 3, axis=0)[None], + np.zeros((1, 1, 3, 3)), ), - axis=-1, + axis=1, ), np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.bool_), - 0.8, ), - (np.zeros((3, 3, 4)), np.zeros((3, 3)), 0.0), + (np.zeros((1, 4, 3, 3)), np.zeros((3, 3))), ], ) - def test_postprocess_masks(self, high_res_masks: np.ndarray, expected_masks: np.ndarray, expected_scores: float): + def test_postprocess_masks(self, masks: np.ndarray, expected_masks: np.ndarray): """Test _postprocess_masks.""" - self.visual_prompting_ov_inferencer.model["decoder"].resize_and_crop.return_value = high_res_masks - self.visual_prompting_ov_inferencer.model["decoder"].mask_threshold = 0.0 - self.visual_prompting_ov_inferencer.model["decoder"].image_size = 3 + self.zero_shot_visual_prompting_ov_inferencer.model["decoder"].mask_threshold = 0.0 + self.zero_shot_visual_prompting_ov_inferencer.model["decoder"].image_size = 3 - _, result_masks, result_scores = self.visual_prompting_ov_inferencer._postprocess_masks( - logits=np.empty((1, 4, 2, 2)), scores=np.array([[0.5, 0.7, 0.8, 0.9]]), original_size=np.array([3, 3]) + _, result_masks = self.zero_shot_visual_prompting_ov_inferencer._postprocess_masks( + masks=masks, logits=np.empty((1, 4, 2, 2)), scores=np.array([[0.5, 0.7, 0.8, 0.9]]) ) assert result_masks.shape == (3, 3) assert np.all(result_masks == expected_masks) - assert result_scores == expected_scores + + @e2e_pytest_unit + def test_inspect_overlapping_areas(self) -> None: + """Test _inspect_overlapping_areas.""" + predicted_masks = { + 0: [ + np.array( + [ + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + ), + np.array( + [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + ), + np.array( + [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0], + [0, 1, 1, 1, 0, 0], + ], + ), + ], + 1: [ + np.array( + [ + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + ), + np.array( + [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1], + [0, 0, 0, 0, 1, 1], + ], + ), + np.array( + [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + 
[0, 1, 1, 0, 0, 0], + [0, 1, 1, 0, 0, 0], + ], + ), + np.array( + [ + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + ), + ], + } + used_points = { + 0: [ + np.array([0, 0, 0.5]), # to be removed + np.array([2, 2, 0.5]), + np.array([1, 4, 0.5]), + ], + 1: [ + np.array([3, 0, 0.5]), + np.array([4, 4, 0.5]), + np.array([1, 4, 0.3]), # to be removed + np.array([0, 0, 0.7]), + ], + } + + self.zero_shot_visual_prompting_ov_inferencer._inspect_overlapping_areas( + predicted_masks, used_points, threshold_iou=0.5 + ) + + assert len(predicted_masks[0]) == 2 + assert len(predicted_masks[1]) == 3 + assert all(np.array([2, 2, 0.5]) == used_points[0][0]) + assert all(np.array([0, 0, 0.7]) == used_points[1][2]) + + @e2e_pytest_unit + def test_find_latest_reference_info(self, mocker): + """Test _find_latest_reference_info.""" + mocker.patch("otx.algorithms.visual_prompting.tasks.openvino.os.path.isdir", return_value=True) + + # there are some saved reference info + mocker.patch("otx.algorithms.visual_prompting.tasks.openvino.os.listdir", return_value=["1", "2"]) + results = self.zero_shot_visual_prompting_ov_inferencer._find_latest_reference_info() + assert results == "2" + + # there are no saved reference info + mocker.patch("otx.algorithms.visual_prompting.tasks.openvino.os.listdir", return_value=[]) + results = self.zero_shot_visual_prompting_ov_inferencer._find_latest_reference_info() + assert results is None + + @e2e_pytest_unit + def test_get_reference_info(self, mocker): + """Test _get_reference_info.""" + mocker.patch("otx.algorithms.visual_prompting.tasks.openvino.os.path.isdir", return_value=True) + + # get previously saved reference info + mocker.patch("otx.algorithms.visual_prompting.tasks.openvino.os.listdir", return_value=["1", "2"]) + mocker.patch( + "otx.algorithms.visual_prompting.tasks.openvino.pickle.load", + return_value={"reference_feats": 1, "used_indices": 2}, + ) + mocker.patch("builtins.open", return_value="Mocked data") + + results = self.zero_shot_visual_prompting_ov_inferencer._get_reference_info() + assert results == (1, 2) + + # no saved reference info + mocker.patch("otx.algorithms.visual_prompting.tasks.openvino.os.listdir", return_value=[]) + + results = self.zero_shot_visual_prompting_ov_inferencer._get_reference_info() + assert results == (None, None) + + @e2e_pytest_unit + def test_expand_reference_info(self): + """Test expand_reference_info.""" + self.zero_shot_visual_prompting_ov_inferencer.reference_feats = np.ones((3, 2, 2)) + new_largest_label = 5 + + self.zero_shot_visual_prompting_ov_inferencer.expand_reference_info(new_largest_label) + + assert self.zero_shot_visual_prompting_ov_inferencer.reference_feats.shape == (6, 2, 2) + assert np.all(self.zero_shot_visual_prompting_ov_inferencer.reference_feats[:3] == 1.0) + assert np.all(self.zero_shot_visual_prompting_ov_inferencer.reference_feats[3:] == 0.0) + + @e2e_pytest_unit + def test_generate_masked_features(self) -> None: + """Test _generate_masked_features.""" + self.zero_shot_visual_prompting_ov_inferencer.model["image_encoder"].image_size = 16 + feats = np.random.rand(8, 8, 1) + masks = np.zeros((16, 16), dtype=np.float32) + masks[4:12, 4:12] = 1.0 + + masked_feat = self.zero_shot_visual_prompting_ov_inferencer._generate_masked_features( + feats=feats, masks=masks, threshold_mask=0.3 + ) + + assert masked_feat.shape == (1, 1) + + @e2e_pytest_unit + def test_pad_to_square(self) -> None: + """Test _pad_to_square.""" + 
self.zero_shot_visual_prompting_ov_inferencer.model["image_encoder"].image_size = 16 + + result = self.zero_shot_visual_prompting_ov_inferencer._pad_to_square(x=np.ones((8, 8))) + + assert result[:8, :8].sum() == 8**2 + assert result[:8, 8:].sum() == 0 + assert result[8:, :8].sum() == 0 + assert result[8:, 8:].sum() == 0 class TestOTXOpenVinoDataLoader: @@ -365,7 +619,7 @@ def test_getitem(self, mocker, load_dataloader, module_name: str): self.mocker_read_model.assert_called_once() self.mocker_compile_model.assert_called_once() assert "label" not in results - assert "orig_size" not in results + assert "orig_size" in results assert "image_embeddings" in results @@ -376,7 +630,12 @@ def _load_dataloader(module_name: str, output_model: Optional[ModelEntity] = Non dataset = generate_visual_prompting_dataset() dataset = dataset.get_subset(Subset.TRAINING) return OTXZeroShotOpenVinoDataLoader( - dataset, self.mocker_inferencer, module_name, output_model=output_model + dataset, + self.mocker_inferencer, + module_name, + output_model=output_model, + reference_feats=np.zeros((1, 1, 1)), + used_indices=np.array([[0]]), ) return _load_dataloader @@ -402,17 +661,14 @@ def test_getitem(self, mocker, load_dataloader, module_name: str): setattr(dataloader, "target_length", 8) mocker.patch.object( dataloader.inferencer, - "pre_process", + "pre_process_image_encoder", return_value=({"images": np.zeros((1, 3, 4, 4), dtype=np.uint8)}, {"original_shape": (4, 4)}), ) if module_name == "decoder": mocker.patch.object( - dataloader, - "prompt_getter", - return_value={ - "total_points_scores": [np.array([[0, 0, 0.5]])], - "total_bg_coords": [np.array([[1, 1]])], - }, + dataloader.inferencer, + "forward_prompt_getter", + return_value=({0: np.array([[0, 0, 0.5]])}, {0: np.array([[1, 1]])}), ) results = dataloader.__getitem__(0) @@ -586,7 +842,52 @@ def setup(self, mocker, otx_model): mocker.patch.object( OpenVINOZeroShotVisualPromptingTask, "load_inferencer", return_value=visual_prompting_ov_inferencer ) - self.visual_prompting_ov_task = OpenVINOZeroShotVisualPromptingTask(task_environment=self.task_environment) + self.zero_shot_visual_prompting_ov_task = OpenVINOZeroShotVisualPromptingTask( + task_environment=self.task_environment + ) + + @e2e_pytest_unit + def test_infer_without_reference_info(self): + """Test infer without reference_info.""" + dataset = generate_visual_prompting_dataset() + + updated_dataset = self.zero_shot_visual_prompting_ov_task.infer( + dataset, InferenceParameters(enable_async_inference=False) + ) + + for updated in updated_dataset: + assert len(updated.annotation_scene.annotations) == 0 + + @e2e_pytest_unit + def test_infer_with_reference_info(self, mocker): + """Test infer with reference_info.""" + fake_annotation = [ + Annotation( + Polygon(points=[Point(0, 0)]), + id=0, + labels=[ScoredLabel(LabelEntity(name="fake", domain="VISUALPROMPTING"), probability=1.0)], + ) + ] + + mocker_predict = mocker.patch.object( + OpenVINOZeroShotVisualPromptingInferencer, "predict", return_value=fake_annotation + ) + mocker.patch.object(ShapeFactory, "shape_produces_valid_crop", return_value=True) + mocker.patch.object( + self.zero_shot_visual_prompting_ov_task.inferencer, "_get_reference_info", return_value=({}, {}) + ) + + dataset = generate_visual_prompting_dataset() + + updated_dataset = self.zero_shot_visual_prompting_ov_task.infer( + dataset, InferenceParameters(enable_async_inference=False) + ) + + for updated in updated_dataset: + assert 
updated.annotation_scene.contains_any([LabelEntity(name="fake", domain="VISUALPROMPTING")]) + + mocker_predict.assert_called() + assert mocker_predict.call_count == len(updated_dataset) @e2e_pytest_unit def test_optimize(self, mocker): @@ -601,43 +902,53 @@ def patch_save_model(model, output_xml): dataset = generate_visual_prompting_dataset() output_model = deepcopy(self.task_environment.model) - self.visual_prompting_ov_task.model.set_data("visual_prompting_image_encoder.xml", b"image_encoder_xml") - self.visual_prompting_ov_task.model.set_data("visual_prompting_image_encoder.bin", b"image_encoder_bin") - self.visual_prompting_ov_task.model.set_data("visual_prompting_prompt_getter.xml", b"prompt_getter_xml") - self.visual_prompting_ov_task.model.set_data("visual_prompting_prompt_getter.bin", b"prompt_getter_bin") - self.visual_prompting_ov_task.model.set_data("visual_prompting_decoder.xml", b"decoder_xml") - self.visual_prompting_ov_task.model.set_data("visual_prompting_decoder.bin", b"decoder_bin") + self.zero_shot_visual_prompting_ov_task.model.set_data( + "visual_prompting_image_encoder.xml", b"image_encoder_xml" + ) + self.zero_shot_visual_prompting_ov_task.model.set_data( + "visual_prompting_image_encoder.bin", b"image_encoder_bin" + ) + self.zero_shot_visual_prompting_ov_task.model.set_data( + "visual_prompting_prompt_getter.xml", b"prompt_getter_xml" + ) + self.zero_shot_visual_prompting_ov_task.model.set_data( + "visual_prompting_prompt_getter.bin", b"prompt_getter_bin" + ) + self.zero_shot_visual_prompting_ov_task.model.set_data("visual_prompting_decoder.xml", b"decoder_xml") + self.zero_shot_visual_prompting_ov_task.model.set_data("visual_prompting_decoder.bin", b"decoder_bin") mocker.patch("otx.algorithms.visual_prompting.tasks.openvino.ov.Core.read_model", autospec=True) mocker.patch("otx.algorithms.visual_prompting.tasks.openvino.ov.save_model", new=patch_save_model) mocker.patch("otx.algorithms.visual_prompting.tasks.openvino.ov.Core.compile_model") fake_quantize = mocker.patch("otx.algorithms.visual_prompting.tasks.openvino.nncf.quantize", autospec=True) - self.visual_prompting_ov_task.optimize(OptimizationType.POT, dataset=dataset, output_model=output_model) + self.zero_shot_visual_prompting_ov_task.optimize( + OptimizationType.POT, dataset=dataset, output_model=output_model + ) fake_quantize.assert_called() assert fake_quantize.call_count == 3 assert ( - self.visual_prompting_ov_task.model.get_data("visual_prompting_image_encoder.xml") + self.zero_shot_visual_prompting_ov_task.model.get_data("visual_prompting_image_encoder.xml") == b"compressed_visual_prompting_image_encoder.xml" ) assert ( - self.visual_prompting_ov_task.model.get_data("visual_prompting_image_encoder.bin") + self.zero_shot_visual_prompting_ov_task.model.get_data("visual_prompting_image_encoder.bin") == b"compressed_visual_prompting_image_encoder.bin" ) assert ( - self.visual_prompting_ov_task.model.get_data("visual_prompting_prompt_getter.xml") + self.zero_shot_visual_prompting_ov_task.model.get_data("visual_prompting_prompt_getter.xml") == b"compressed_visual_prompting_prompt_getter.xml" ) assert ( - self.visual_prompting_ov_task.model.get_data("visual_prompting_prompt_getter.bin") + self.zero_shot_visual_prompting_ov_task.model.get_data("visual_prompting_prompt_getter.bin") == b"compressed_visual_prompting_prompt_getter.bin" ) assert ( - self.visual_prompting_ov_task.model.get_data("visual_prompting_decoder.xml") + self.zero_shot_visual_prompting_ov_task.model.get_data("visual_prompting_decoder.xml") 
== b"compressed_visual_prompting_decoder.xml" ) assert ( - self.visual_prompting_ov_task.model.get_data("visual_prompting_decoder.bin") + self.zero_shot_visual_prompting_ov_task.model.get_data("visual_prompting_decoder.bin") == b"compressed_visual_prompting_decoder.bin" ) diff --git a/tests/unit/algorithms/visual_prompting/test_helpers.py b/tests/unit/algorithms/visual_prompting/test_helpers.py index c1be0ae3c89..445bc0f2ba1 100644 --- a/tests/unit/algorithms/visual_prompting/test_helpers.py +++ b/tests/unit/algorithms/visual_prompting/test_helpers.py @@ -17,6 +17,8 @@ AnnotationSceneEntity, AnnotationSceneKind, ) +from unittest.mock import Mock +from otx.api.entities.scored_label import ScoredLabel from otx.api.entities.color import Color from otx.api.entities.dataset_item import DatasetItemEntity from otx.api.entities.datasets import DatasetEntity @@ -148,12 +150,8 @@ def __init__(self, use_mask: bool = False): class MockImageEncoder(nn.Module): - def __init__(self, *args, **kwargs): - super().__init__() - self.backbone = nn.Linear(1, 1) - - def forward(self, *args, **kwargs): - return torch.ones((1, 2, 4, 4)) + def __new__(cls, *args, **kwargs): + return nn.Linear(4, 4) class MockPromptEncoder(nn.Module): @@ -186,9 +184,20 @@ def predict_mask(self, *args, **kwargs): class MockScoredLabel: - def __init__(self, label: int, name: str = "background"): + def __init__( + self, + label: int, + name: str = "background", + probability: float = 0.0, + label_source=None, + ): self.name = name - self.id_ = label + self.label = Mock() + self.label.id_ = label + self.label.id = label + self.probability = probability + self.label_source = label_source + self.__class__ = ScoredLabel class MockPromptGetter(nn.Module): @@ -202,7 +211,7 @@ def set_default_thresholds(self, *args, **kwargs): pass def get_prompt_candidates(self, *args, **kwargs): - return {1: (torch.Tensor([[0, 0, 0.5]]), torch.Tensor([[1, 1]]))} + return {1: torch.Tensor([[0, 0, 0.5]])}, {1: torch.Tensor([[1, 1]])} def forward(self, *args, **kwargs): return torch.tensor([[[0, 0, 0.5], [1, 1, 0.7]]]), torch.tensor([[[2, 2]]])