From 68dcbe67409afd7d401ed85b501b7d0f838f29fe Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Thu, 25 Jul 2024 14:37:17 +0900 Subject: [PATCH 01/14] Update license and docstring --- .../algo/visual_prompting/segment_anything.py | 2 +- .../zero_shot_segment_anything.py | 2 +- src/otx/core/data/dataset/visual_prompting.py | 20 +++++++++++++++++-- src/otx/core/model/visual_prompting.py | 4 ++-- 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/otx/algo/visual_prompting/segment_anything.py b/src/otx/algo/visual_prompting/segment_anything.py index be318a1d2ea..baac47fe114 100644 --- a/src/otx/algo/visual_prompting/segment_anything.py +++ b/src/otx/algo/visual_prompting/segment_anything.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Segment Anything model for the OTX visual prompting.""" diff --git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py index dd650486e30..846e32c6e9d 100644 --- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py +++ b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Segment Anything model for the OTX zero-shot visual prompting.""" diff --git a/src/otx/core/data/dataset/visual_prompting.py b/src/otx/core/data/dataset/visual_prompting.py index be7c3a1f0c5..1eba5a7f255 100644 --- a/src/otx/core/data/dataset/visual_prompting.py +++ b/src/otx/core/data/dataset/visual_prompting.py @@ -1,6 +1,6 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -# + """Module for OTXVisualPromptingDataset.""" from __future__ import annotations @@ -38,6 +38,14 @@ class OTXVisualPromptingDataset(OTXDataset[VisualPromptingDataEntity]): Args: dm_subset (dmDataset): The subset of the dataset. transforms (Transforms): Data transformations to be applied. + use_bbox (bool): Whether to use bounding box prompt. + If both use_bbox and use_point are False, use_bbox is set to True as default. + If both are True, divide the probability into both. + Defaults to True. + use_point (bool): Whether to use point prompt. + If both use_bbox and use_point are False, use_bbox is set to True as default. + If both are True, divide the probability into both. + Defaults to False. **kwargs: Additional keyword arguments passed to the base class. """ @@ -168,6 +176,14 @@ class OTXZeroShotVisualPromptingDataset(OTXDataset[ZeroShotVisualPromptingDataEn Args: dm_subset (dmDataset): The subset of the dataset. transforms (Transforms): Data transformations to be applied. + use_bbox (bool): Whether to use bounding box prompt. + If both use_bbox and use_point are False, use_bbox is set to True as default. + If both are True, divide the probability into both. + Defaults to True. + use_point (bool): Whether to use point prompt. + If both use_bbox and use_point are False, use_bbox is set to True as default. + If both are True, divide the probability into both. + Defaults to False. **kwargs: Additional keyword arguments passed to the base class. 
""" diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py index 7a4fa917993..5e9bcb500a7 100644 --- a/src/otx/core/model/visual_prompting.py +++ b/src/otx/core/model/visual_prompting.py @@ -1,6 +1,6 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -# + """Class definition for visual prompting models entity used in OTX.""" from __future__ import annotations From bf966ae3d8fb4a718c2b69bc76f3acffb8da9adb Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Thu, 25 Jul 2024 16:03:47 +0900 Subject: [PATCH 02/14] Update dataset and entity --- src/otx/core/data/dataset/visual_prompting.py | 47 ++++++------ src/otx/core/data/entity/visual_prompting.py | 76 +++++++++---------- 2 files changed, 60 insertions(+), 63 deletions(-) diff --git a/src/otx/core/data/dataset/visual_prompting.py b/src/otx/core/data/dataset/visual_prompting.py index 1eba5a7f255..290315ee239 100644 --- a/src/otx/core/data/dataset/visual_prompting.py +++ b/src/otx/core/data/dataset/visual_prompting.py @@ -10,7 +10,6 @@ from typing import Callable import torch -import torchvision.transforms.v2.functional as F # noqa: N812 from datumaro import Bbox as dmBbox from datumaro import Dataset as dmDataset from datumaro import Image as dmImage @@ -18,6 +17,9 @@ from datumaro import Points as dmPoints from datumaro import Polygon as dmPolygon from torchvision import tv_tensors +from torchvision.transforms.v2.functional import convert_bounding_box_format, to_image +from torchvision.tv_tensors import BoundingBoxes as tvBoundingBoxes +from torchvision.tv_tensors import Mask as tvMask from otx.core.data.entity.base import ImageInfo, Points from otx.core.data.entity.visual_prompting import ( @@ -84,7 +86,7 @@ def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None: for annotation in item.annotations: if isinstance(annotation, dmPolygon): - mask = tv_tensors.Mask(polygon_to_bitmap([annotation], *img_shape)[0]) + mask = tvMask(polygon_to_bitmap([annotation], *img_shape)[0]) mask_points = torch.nonzero(mask) if len(mask_points[0]) == 0: # skip very small region @@ -92,16 +94,13 @@ def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None: if torch.rand(1) < self.prob: # get bbox - bbox = tv_tensors.BoundingBoxes( + bbox = tvBoundingBoxes( annotation.get_bbox(), - format=tv_tensors.BoundingBoxFormat.XYWH, + format="xywh", canvas_size=img_shape, dtype=torch.float32, ) - bbox = F._meta.convert_bounding_box_format( # noqa: SLF001 - bbox, - new_format=tv_tensors.BoundingBoxFormat.XYXY, - ) + bbox = convert_bounding_box_format(bbox, new_format="xyxy") gt_bboxes.append(bbox) gt_labels["bboxes"].append(annotation.label) gt_masks["bboxes"].append(mask) @@ -135,7 +134,7 @@ def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None: bboxes = tv_tensors.wrap(torch.cat(gt_bboxes, dim=0), like=gt_bboxes[0]) if len(gt_bboxes) > 0 else None points = tv_tensors.wrap(torch.stack(gt_points, dim=0), like=gt_points[0]) if len(gt_points) > 0 else None labels = {prompt_type: torch.as_tensor(values, dtype=torch.int64) for prompt_type, values in gt_labels.items()} - masks = tv_tensors.Mask( + masks = tvMask( torch.stack(gt_masks.get("bboxes", []) + gt_masks.get("points", []), dim=0), dtype=torch.uint8, ) @@ -215,11 +214,14 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None img = item.media_as(dmImage) img_data, img_shape = self._get_img_data_and_shape(img) - gt_prompts, gt_masks, 
gt_polygons, gt_labels = [], [], [], [] + gt_prompts: list[tvBoundingBoxes | Points] = [] + gt_masks: list[tvMask] = [] + gt_polygons: list[dmPolygon] = [] + gt_labels: dict[str, list[int]] = defaultdict(list) for annotation in item.annotations: if isinstance(annotation, dmPolygon): # generate prompts from polygon - mask = tv_tensors.Mask(polygon_to_bitmap([annotation], *img_shape)[0]) + mask = tvMask(polygon_to_bitmap([annotation], *img_shape)[0]) mask_points = torch.nonzero(mask) if len(mask_points[0]) == 0: # skip very small region @@ -227,16 +229,13 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None if torch.rand(1) < self.prob: # get bbox - bbox = tv_tensors.BoundingBoxes( + bbox = tvBoundingBoxes( annotation.get_bbox(), - format=tv_tensors.BoundingBoxFormat.XYWH, + format="xywh", canvas_size=img_shape, dtype=torch.float32, ) - bbox = F._meta.convert_bounding_box_format( # noqa: SLF001 - bbox, - new_format=tv_tensors.BoundingBoxFormat.XYXY, - ) + bbox = convert_bounding_box_format(bbox, new_format="xyxy") gt_prompts.append(bbox) else: # get center point @@ -247,7 +246,9 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None ) gt_prompts.append(point) - gt_labels.append(annotation.label) + gt_labels["prompts"].append(annotation.label) + gt_labels["polygons"].append(annotation.label) + gt_labels["masks"].append(annotation.label) gt_masks.append(mask) gt_polygons.append(annotation) @@ -255,14 +256,14 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None elif isinstance(annotation, (dmBbox, dmMask, dmPoints)): pass - assert len(gt_prompts) > 0, "#prompts must be greater than 0." # noqa: S101 + if len(gt_prompts) == 0: + return None - labels = torch.as_tensor(gt_labels, dtype=torch.int64) - masks = tv_tensors.Mask(torch.stack(gt_masks, dim=0), dtype=torch.uint8) + labels = {prompt_type: torch.as_tensor(values, dtype=torch.int64) for prompt_type, values in gt_labels.items()} + masks = tvMask(torch.stack(gt_masks, dim=0), dtype=torch.uint8) - # set entity without masks to avoid resizing masks return ZeroShotVisualPromptingDataEntity( - image=F.to_image(img_data), + image=to_image(img_data), img_info=ImageInfo( img_idx=index, img_shape=img_shape, diff --git a/src/otx/core/data/entity/visual_prompting.py b/src/otx/core/data/entity/visual_prompting.py index 2924a65ab64..09e4b68cb56 100644 --- a/src/otx/core/data/entity/visual_prompting.py +++ b/src/otx/core/data/entity/visual_prompting.py @@ -1,6 +1,6 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -# + """Module for OTX visual prompting data entities.""" from __future__ import annotations @@ -10,19 +10,15 @@ from torchvision import tv_tensors -from otx.core.data.entity.base import ( - OTXBatchDataEntity, - OTXBatchPredEntity, - OTXDataEntity, - OTXPredEntity, - Points, -) +from otx.core.data.entity.base import OTXBatchDataEntity, OTXBatchPredEntity, OTXDataEntity, OTXPredEntity, Points from otx.core.data.entity.utils import register_pytree_node from otx.core.types.task import OTXTaskType if TYPE_CHECKING: - from datumaro import Polygon + from datumaro import Polygon as dmPolygon from torch import LongTensor + from torchvision.tv_tensors import BoundingBoxes as tvBoundingBoxes + from torchvision.tv_tensors import Mask as tvMask @register_pytree_node @@ -31,10 +27,10 @@ class VisualPromptingDataEntity(OTXDataEntity): """Data entity for visual prompting task. 
Attributes: - masks (tv_tensors.Mask): The masks of the instances. - labels (LongTensor): The labels of the instances. - polygons (list[Polygon]): The polygons of the instances. - bboxes (tv_tensors.BoundingBoxes): The bounding boxes of the instances. + masks (tvMask): The masks of the instances. + labels (dict[str, LongTensor]): The labels of the instances. + polygons (list[dmPolygon]): The polygons of the instances. + bboxes (tvBoundingBoxes): The bounding boxes of the instances. points (Points): The points of the instances. """ @@ -43,10 +39,10 @@ def task(self) -> OTXTaskType: """OTX Task type definition.""" return OTXTaskType.VISUAL_PROMPTING - masks: tv_tensors.Mask + masks: tvMask labels: dict[str, LongTensor] - polygons: list[Polygon] - bboxes: tv_tensors.BoundingBoxes + polygons: list[dmPolygon] + bboxes: tvBoundingBoxes points: Points @@ -60,17 +56,17 @@ class VisualPromptingBatchDataEntity(OTXBatchDataEntity[VisualPromptingDataEntit """Data entity for visual prompting task. Attributes: - masks (list[tv_tensors.Mask]): List of masks. - labels (list[LongTensor]): List of labels. - polygons (list[list[Polygon]]): List of polygons. - bboxes (list[tv_tensors.BoundingBoxes]): List of bounding boxes. + masks (list[tvMask]): List of masks. + labels (list[dict[str, LongTensor]]): List of labels. + polygons (list[list[dmPolygon]]): List of polygons. + bboxes (list[tvBoundingBoxes]): List of bounding boxes. points (list[Points]): List of points. """ - masks: list[tv_tensors.Mask] + masks: list[tvMask] labels: list[dict[str, LongTensor]] - polygons: list[list[Polygon]] - bboxes: list[tv_tensors.BoundingBoxes] + polygons: list[list[dmPolygon]] + bboxes: list[tvBoundingBoxes] points: list[Points] @property @@ -137,10 +133,10 @@ class ZeroShotVisualPromptingDataEntity(OTXDataEntity): """Data entity for zero-shot visual prompting task. Attributes: - masks (tv_tensors.Mask): The masks of the instances. - labels (LongTensor): The labels of the instances. - polygons (list[Polygon]): The polygons of the instances. - prompts (list[tv_tensors.TVTensor]): The prompts of the instances. + masks (tvMask): The masks of the instances. + labels (dict[str, list[LongTensor]]): The labels of the instances for each prompt. + polygons (list[dmPolygon]): The polygons of the instances. + prompts (list[tvBoundingBoxes | Points]): The prompts of the instances. """ @property @@ -148,10 +144,10 @@ def task(self) -> OTXTaskType: """OTX Task type definition.""" return OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING - masks: tv_tensors.Mask - labels: list[LongTensor] - polygons: list[Polygon] - prompts: list[tv_tensors.TVTensor] + masks: tvMask + labels: dict[str, LongTensor] + polygons: list[dmPolygon] + prompts: list[tvBoundingBoxes | Points] @dataclass @@ -159,16 +155,16 @@ class ZeroShotVisualPromptingBatchDataEntity(OTXBatchDataEntity[ZeroShotVisualPr """Data entity for zero-shot visual prompting task. Attributes: - masks (list[tv_tensors.Mask]): List of masks. - labels (list[LongTensor]): List of labels. - polygons (list[list[Polygon]]): List of polygons. - prompts (list[list[tv_tensors.TVTensor]]): List of prompts. + masks (list[tvMask]): List of masks. + labels (list[dict[str, LongTensor]]): List of labels. + polygons (list[list[dmPolygon]]): List of polygons. + prompts (list[list[tvBoundingBoxes | Points]]): List of prompts. 
""" - masks: list[tv_tensors.Mask] - labels: list[LongTensor] - polygons: list[list[Polygon]] - prompts: list[list[tv_tensors.TVTensor]] + masks: list[tvMask] + labels: list[dict[str, LongTensor]] + polygons: list[list[dmPolygon]] + prompts: list[list[tvBoundingBoxes | Points]] @property def task(self) -> OTXTaskType: From 746458daac3cef87f011072648e68a31f6ba4f18 Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Thu, 25 Jul 2024 18:03:49 +0900 Subject: [PATCH 03/14] Update types --- src/otx/core/data/dataset/visual_prompting.py | 4 ++-- src/otx/core/data/entity/visual_prompting.py | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/otx/core/data/dataset/visual_prompting.py b/src/otx/core/data/dataset/visual_prompting.py index 290315ee239..fd368b11dab 100644 --- a/src/otx/core/data/dataset/visual_prompting.py +++ b/src/otx/core/data/dataset/visual_prompting.py @@ -7,7 +7,7 @@ from collections import defaultdict from functools import partial -from typing import Callable +from typing import Callable, Literal import torch from datumaro import Bbox as dmBbox @@ -217,7 +217,7 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None gt_prompts: list[tvBoundingBoxes | Points] = [] gt_masks: list[tvMask] = [] gt_polygons: list[dmPolygon] = [] - gt_labels: dict[str, list[int]] = defaultdict(list) + gt_labels: dict[Literal["prompts", "polygons", "masks"], list[int]] = defaultdict(list) for annotation in item.annotations: if isinstance(annotation, dmPolygon): # generate prompts from polygon diff --git a/src/otx/core/data/entity/visual_prompting.py b/src/otx/core/data/entity/visual_prompting.py index 09e4b68cb56..0a6710baab0 100644 --- a/src/otx/core/data/entity/visual_prompting.py +++ b/src/otx/core/data/entity/visual_prompting.py @@ -6,7 +6,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from torchvision import tv_tensors @@ -134,7 +134,8 @@ class ZeroShotVisualPromptingDataEntity(OTXDataEntity): Attributes: masks (tvMask): The masks of the instances. - labels (dict[str, list[LongTensor]]): The labels of the instances for each prompt. + labels (dict[Literal["prompts", "polygons", "masks"], list[LongTensor]]): The labels of the instances + for each prompt. polygons (list[dmPolygon]): The polygons of the instances. prompts (list[tvBoundingBoxes | Points]): The prompts of the instances. """ @@ -145,7 +146,7 @@ def task(self) -> OTXTaskType: return OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING masks: tvMask - labels: dict[str, LongTensor] + labels: dict[Literal["prompts", "polygons", "masks"], LongTensor] polygons: list[dmPolygon] prompts: list[tvBoundingBoxes | Points] @@ -156,13 +157,13 @@ class ZeroShotVisualPromptingBatchDataEntity(OTXBatchDataEntity[ZeroShotVisualPr Attributes: masks (list[tvMask]): List of masks. - labels (list[dict[str, LongTensor]]): List of labels. + labels (list[dict[Literal["prompts", "polygons", "masks"], LongTensor]]): List of labels. polygons (list[list[dmPolygon]]): List of polygons. prompts (list[list[tvBoundingBoxes | Points]]): List of prompts. 
""" masks: list[tvMask] - labels: list[dict[str, LongTensor]] + labels: list[dict[Literal["prompts", "polygons", "masks"], LongTensor]] polygons: list[list[dmPolygon]] prompts: list[list[tvBoundingBoxes | Points]] @@ -207,7 +208,7 @@ def pin_memory(self) -> ZeroShotVisualPromptingBatchDataEntity: for prompts in self.prompts ], masks=[tv_tensors.wrap(mask.pin_memory(), like=mask) for mask in self.masks], - labels=[label.pin_memory() for label in self.labels], + labels=[{k: v.pin_memory() for k, v in label.items()} for label in self.labels], ) ) From a97cb8817f9dffd5489901de33d83505e359aa2f Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Thu, 25 Jul 2024 18:03:54 +0900 Subject: [PATCH 04/14] Update for modified data entity --- .../zero_shot_segment_anything.py | 76 +++++++++++++------ src/otx/core/model/visual_prompting.py | 7 +- 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py index 846e32c6e9d..3b321a2e720 100644 --- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py +++ b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py @@ -19,7 +19,7 @@ from torch import Tensor, nn from torch.nn import functional as F # noqa: N812 from torchvision import tv_tensors -from torchvision.tv_tensors import BoundingBoxes, Image, Mask, TVTensor +from torchvision.tv_tensors import BoundingBoxes, Image, Mask from otx.algo.visual_prompting.segment_anything import DEFAULT_CONFIG_SEGMENT_ANYTHING, SegmentAnything from otx.core.data.entity.base import OTXBatchLossEntity, Points @@ -230,7 +230,7 @@ def expand_reference_info(self, reference_feats: Tensor, new_largest_label: int) def learn( self, images: list[Image], - processed_prompts: list[dict[int, list[TVTensor]]], + processed_prompts: list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]], reference_feats: Tensor, used_indices: Tensor, ori_shapes: list[Tensor], @@ -244,7 +244,7 @@ def learn( Args: images (list[Image]): List of given images for reference features. - processed_prompts (dict[int, list[TVTensor]]): The class-wise prompts + processed_prompts (dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]): The class-wise prompts processed at OTXZeroShotSegmentAnything._gather_prompts_with_labels. reference_feats (Tensor): Reference features for target prediction. used_indices (Tensor): To check which indices of reference features are validate. @@ -252,7 +252,7 @@ def learn( is_cascade (bool): Whether use cascade inference. Defaults to False. 
""" # initialize tensors to contain reference features and prompts - largest_label = max(sum([[int(p) for p in prompt] for prompt in processed_prompts], [])) + largest_label = max([max(prompt.keys()) for prompt in processed_prompts]) reference_feats = self.expand_reference_info(reference_feats, largest_label) new_used_indices: list[Tensor] = [] # TODO (sungchul): consider how to handle multiple reference features, currently replace it @@ -270,7 +270,9 @@ def learn( for input_prompt in input_prompts: if isinstance(input_prompt, Mask): # directly use annotation information as a mask - ref_mask[input_prompt == 1] += 1 # TODO (sungchul): check if the mask is bool or int + ref_mask[input_prompt] += 1 + elif isinstance(input_prompt, dmPolygon): + continue else: if isinstance(input_prompt, BoundingBoxes): point_coords = input_prompt.reshape(-1, 2, 2) @@ -278,12 +280,6 @@ def learn( elif isinstance(input_prompt, Points): point_coords = input_prompt.reshape(-1, 1, 2) point_labels = torch.tensor([[1]], device=point_coords.device) - elif isinstance( - input_prompt, - dmPolygon, - ): # TODO (sungchul): add other polygon types - # TODO (sungchul): convert polygon to mask - continue else: log.info(f"Current input prompt ({input_prompt.__class__.__name__}) is not supported.") continue @@ -621,6 +617,16 @@ def _decide_cascade_results( best_idx = torch.argmax(scores[0]) return logits[:, [best_idx]], masks[0, best_idx] + # def _polygon_to_mask(self, polygon: np.ndarray | list[np.ndarray], height: int, width: int) -> np.ndarray: + # """Converts a polygon represented as an array of 2D points into a mask.""" + # if isinstance(polygon, np.ndarray) and np.issubdtype(polygon.dtype, np.integer): + # contour = polygon.reshape(-1, 2) + # else: + # contour = [[int(point[0]), int(point[1])] for point in polygon] + # gt_mask = np.zeros((height, width), dtype=np.uint8) + # gt_mask = cv2.drawContours(gt_mask, np.asarray([contour]), 0, 1, cv2.FILLED) + # return gt_mask + class OTXZeroShotSegmentAnything(OTXZeroShotVisualPromptingModel): """Zero-Shot Visual Prompting model.""" @@ -745,7 +751,14 @@ def _customize_inputs( # type: ignore[override] if self.training: # learn forward_inputs.update( - {"processed_prompts": self._gather_prompts_with_labels(inputs.prompts, inputs.labels)}, + { + "processed_prompts": self._gather_prompts_with_labels( + inputs.labels, + inputs.prompts, + inputs.polygons, + inputs.masks, + ), + }, ) return forward_inputs @@ -810,17 +823,29 @@ def _customize_outputs( # type: ignore[override] def _gather_prompts_with_labels( self, - prompts: list[list[TVTensor]], - labels: list[Tensor], - ) -> list[dict[int, list[TVTensor]]]: + labels: list[dict[Literal["prompts", "polygons", "masks"], Tensor]], + prompts: list[list[BoundingBoxes | Points]] | None = None, + polygons: list[list[dmPolygon]] | None = None, + masks: list[Mask] | None = None, + ) -> list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]]: """Gather prompts according to labels.""" - total_processed_prompts: list[dict[int, list[TVTensor]]] = [] - for prompt, label in zip(prompts, labels): + total_processed_prompts: list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]] = [] + for batch, batch_labels in enumerate(labels): processed_prompts = defaultdict(list) - for _prompt, _label in zip(prompt, label): # type: ignore[arg-type] - processed_prompts[int(_label)].append(_prompt) + for prompt_type, prompt_labels in batch_labels.items(): + _prompts = locals()[prompt_type] + if _prompts is None: + continue + for idx, _label in 
enumerate(prompt_labels): + if prompt_type in ("prompts", "polygons"): + processed_prompts[int(_label)].append(_prompts[batch][idx]) + else: + # for mask + processed_prompts[int(_label)].append(Mask(_prompts[idx] == _label)) + sorted_processed_prompts = dict(sorted(processed_prompts.items(), key=lambda x: x)) total_processed_prompts.append(sorted_processed_prompts) + return total_processed_prompts def apply_image(self, image: Image | np.ndarray, target_length: int = 1024) -> Image: @@ -876,14 +901,16 @@ def get_preprocess_shape(self, oldh: int, oldw: int, target_length: int) -> tupl newh = int(newh + 0.5) return (newh, neww) - def preprocess(self, x: Image) -> Image: + def preprocess(self, x: Image | Mask) -> Image | Mask: """Normalize pixel values and pad to a square input.""" - # Normalize colors - x = (x - self.pixel_mean) / self.pixel_std + # TODO (sungchul): get type to convert tensor at L912 + if isinstance(x, Image): + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std # Pad x = self.model.pad_to_square(x) - return Image(x) + return Image(x) # TODO (sungchul): convert type def transforms(self, entity: ZeroShotVisualPromptingBatchDataEntity) -> ZeroShotVisualPromptingBatchDataEntity: """Transforms for ZeroShotVisualPromptingBatchDataEntity.""" @@ -893,6 +920,9 @@ def transforms(self, entity: ZeroShotVisualPromptingBatchDataEntity) -> ZeroShot self.apply_prompts(prompt, info.ori_shape, self.model.image_size) for prompt, info in zip(entity.prompts, entity.imgs_info) ], + # TODO (sungchul): add masks and polygons + masks=None, + polygons=None, ) def initialize_reference_info(self) -> None: diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py index 5e9bcb500a7..6f9eca69207 100644 --- a/src/otx/core/model/visual_prompting.py +++ b/src/otx/core/model/visual_prompting.py @@ -130,6 +130,9 @@ def _inference_step_for_zero_shot( if not isinstance(preds, ZeroShotVisualPromptingBatchPredEntity): raise TypeError(preds) + # filter labels using corresponding ground truth + inputs.labels = [label["masks"] if len(inputs.masks) > 0 else label["polygons"] for label in inputs.labels] + converted_entities: dict[str, list[dict[str, Tensor]]] = _convert_pred_entity_to_compute_metric(preds, inputs) # type: ignore[assignment] for _name, _metric in metric.items(): @@ -923,7 +926,7 @@ def _customize_inputs( # type: ignore[override] if self.training: points: list[Prompt] = [] bboxes: list[Prompt] = [] - for prompt, label in zip(prompts, labels): # type: ignore[arg-type] + for prompt, label in zip(prompts, labels["prompts"]): # type: ignore[arg-type] if isinstance(prompt, tv_tensors.BoundingBoxes): bboxes.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy())) elif isinstance(prompt, Points): @@ -1047,7 +1050,7 @@ def transform_fn( _labels: dict[str, list[int]] = defaultdict(list) # use only the first prompt - for prompt, label in zip(data_batch.prompts[0], data_batch.labels[0]): # type: ignore[arg-type] + for prompt, label in zip(data_batch.prompts[0], data_batch.labels[0]["prompts"]): # type: ignore[arg-type] if isinstance(prompt, tv_tensors.BoundingBoxes): bboxes.append(prompt.cpu().numpy()) _labels["bboxes"].append(label.cpu().numpy()) From 409063c4aae0eae1ed17cea733776085588bf7aa Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Fri, 26 Jul 2024 15:56:02 +0900 Subject: [PATCH 05/14] Enable to use bitmap mask as reference mask --- .../algo/visual_prompting/zero_shot_segment_anything.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff 
--git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py index 3b321a2e720..f3d075d0bda 100644 --- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py +++ b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py @@ -272,6 +272,7 @@ def learn( # directly use annotation information as a mask ref_mask[input_prompt] += 1 elif isinstance(input_prompt, dmPolygon): + # TODO (sungchul): polygon support continue else: if isinstance(input_prompt, BoundingBoxes): @@ -841,7 +842,7 @@ def _gather_prompts_with_labels( processed_prompts[int(_label)].append(_prompts[batch][idx]) else: # for mask - processed_prompts[int(_label)].append(Mask(_prompts[idx] == _label)) + processed_prompts[int(_label)].append(Mask(_prompts[batch][idx])) sorted_processed_prompts = dict(sorted(processed_prompts.items(), key=lambda x: x)) total_processed_prompts.append(sorted_processed_prompts) @@ -920,9 +921,8 @@ def transforms(self, entity: ZeroShotVisualPromptingBatchDataEntity) -> ZeroShot self.apply_prompts(prompt, info.ori_shape, self.model.image_size) for prompt, info in zip(entity.prompts, entity.imgs_info) ], - # TODO (sungchul): add masks and polygons - masks=None, - polygons=None, + masks=entity.masks, + polygons=entity.polygons, ) def initialize_reference_info(self) -> None: From 2f6bd130b4296017c83bd655466324d72cadbfd6 Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Fri, 26 Jul 2024 16:09:41 +0900 Subject: [PATCH 06/14] Enable to use polygon as reference mask --- .../visual_prompting/zero_shot_segment_anything.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py index f3d075d0bda..02abbbb328d 100644 --- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py +++ b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py @@ -32,6 +32,7 @@ from otx.core.model.visual_prompting import OTXZeroShotVisualPromptingModel from otx.core.schedulers import LRSchedulerListCallable from otx.core.types.label import LabelInfoTypes, NullLabelInfo +from otx.core.utils.mask_util import polygon_to_bitmap if TYPE_CHECKING: import numpy as np @@ -272,8 +273,7 @@ def learn( # directly use annotation information as a mask ref_mask[input_prompt] += 1 elif isinstance(input_prompt, dmPolygon): - # TODO (sungchul): polygon support - continue + ref_mask[torch.as_tensor(polygon_to_bitmap([input_prompt], *ori_shape)[0])] += 1 else: if isinstance(input_prompt, BoundingBoxes): point_coords = input_prompt.reshape(-1, 2, 2) @@ -618,16 +618,6 @@ def _decide_cascade_results( best_idx = torch.argmax(scores[0]) return logits[:, [best_idx]], masks[0, best_idx] - # def _polygon_to_mask(self, polygon: np.ndarray | list[np.ndarray], height: int, width: int) -> np.ndarray: - # """Converts a polygon represented as an array of 2D points into a mask.""" - # if isinstance(polygon, np.ndarray) and np.issubdtype(polygon.dtype, np.integer): - # contour = polygon.reshape(-1, 2) - # else: - # contour = [[int(point[0]), int(point[1])] for point in polygon] - # gt_mask = np.zeros((height, width), dtype=np.uint8) - # gt_mask = cv2.drawContours(gt_mask, np.asarray([contour]), 0, 1, cv2.FILLED) - # return gt_mask - class OTXZeroShotSegmentAnything(OTXZeroShotVisualPromptingModel): """Zero-Shot Visual Prompting model.""" From a425190b5050f81a5f10cfd3fd6432bfb3e43338 Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Fri, 
26 Jul 2024 16:23:06 +0900 Subject: [PATCH 07/14] Revert unnecessary changes --- .../visual_prompting/zero_shot_segment_anything.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py index 02abbbb328d..a5aed3cc852 100644 --- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py +++ b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py @@ -892,16 +892,14 @@ def get_preprocess_shape(self, oldh: int, oldw: int, target_length: int) -> tupl newh = int(newh + 0.5) return (newh, neww) - def preprocess(self, x: Image | Mask) -> Image | Mask: + def preprocess(self, x: Image) -> Image: """Normalize pixel values and pad to a square input.""" - # TODO (sungchul): get type to convert tensor at L912 - if isinstance(x, Image): - # Normalize colors - x = (x - self.pixel_mean) / self.pixel_std + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std # Pad x = self.model.pad_to_square(x) - return Image(x) # TODO (sungchul): convert type + return Image(x) def transforms(self, entity: ZeroShotVisualPromptingBatchDataEntity) -> ZeroShotVisualPromptingBatchDataEntity: """Transforms for ZeroShotVisualPromptingBatchDataEntity.""" From 034bf15304645ca89ef20de5d9b3aac0a52578c7 Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Fri, 26 Jul 2024 17:40:46 +0900 Subject: [PATCH 08/14] Support polygons for ov model --- src/otx/core/model/visual_prompting.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py index 6f9eca69207..85a75df0417 100644 --- a/src/otx/core/model/visual_prompting.py +++ b/src/otx/core/model/visual_prompting.py @@ -14,6 +14,7 @@ import numpy as np import torch +from datumaro import Polygon as dmPolygon from model_api.models import Model from model_api.models.visual_prompting import ( Prompt, @@ -914,9 +915,10 @@ def _customize_inputs( # type: ignore[override] images: list[np.ndarray] = [] processed_prompts: list[dict[str, Any]] = [] - for image, prompts, labels in zip( + for image, prompts, polygons, labels in zip( entity.images, entity.prompts, + entity.polygons, entity.labels, ): # preprocess image encoder inputs @@ -924,20 +926,27 @@ def _customize_inputs( # type: ignore[override] images.append(numpy_image) if self.training: - points: list[Prompt] = [] - bboxes: list[Prompt] = [] + _bboxes: list[Prompt] = [] + _points: list[Prompt] = [] + _polygons: list[Prompt] = [] for prompt, label in zip(prompts, labels["prompts"]): # type: ignore[arg-type] if isinstance(prompt, tv_tensors.BoundingBoxes): - bboxes.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy())) + _bboxes.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy())) elif isinstance(prompt, Points): - points.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy())) - # TODO (sungchul): support polygons + _points.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy())) + + if polygons: + for polygon, label in zip(polygons, labels["polygons"]): + _polygons.append(Prompt(np.array(polygon.points, dtype=np.int32), label.cpu().numpy())) + + # TODO (sungchul, sovrasov): support mask? 
# preprocess decoder inputs processed_prompts.append( { - "boxes": bboxes, - "points": points, + "boxes": _bboxes, + "points": _points, + "polygons": _polygons, }, ) From 58934910c3c08d3a7106c863539f28d179492ce3 Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Fri, 26 Jul 2024 18:02:40 +0900 Subject: [PATCH 09/14] Fix unit tests --- tests/unit/algo/visual_prompting/conftest.py | 120 +----------------- .../test_zero_shot_segment_anything.py | 4 +- tests/unit/core/conftest.py | 12 +- .../data/dataset/test_visual_prompting.py | 3 +- .../unit/core/model/test_visual_prompting.py | 8 +- 5 files changed, 18 insertions(+), 129 deletions(-) diff --git a/tests/unit/algo/visual_prompting/conftest.py b/tests/unit/algo/visual_prompting/conftest.py index f9567775649..8e31db96703 100644 --- a/tests/unit/algo/visual_prompting/conftest.py +++ b/tests/unit/algo/visual_prompting/conftest.py @@ -1,122 +1,6 @@ # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations +from tests.unit.core.conftest import fxt_vpm_data_entity, fxt_zero_shot_vpm_data_entity -import pytest -import torch -from otx.core.data.entity.base import ImageInfo, Points -from otx.core.data.entity.visual_prompting import ( - VisualPromptingBatchDataEntity, - VisualPromptingBatchPredEntity, - VisualPromptingDataEntity, - ZeroShotVisualPromptingBatchDataEntity, - ZeroShotVisualPromptingBatchPredEntity, - ZeroShotVisualPromptingDataEntity, -) -from torchvision import tv_tensors - - -@pytest.fixture(scope="session") -def fxt_vpm_data_entity() -> ( - tuple[VisualPromptingDataEntity, VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity] -): - img_size = (1024, 1024) - fake_image = tv_tensors.Image(torch.rand(img_size)) - fake_image_info = ImageInfo(img_idx=0, img_shape=img_size, ori_shape=img_size) - fake_bboxes = tv_tensors.BoundingBoxes( - [[0, 0, 1, 1]], - format=tv_tensors.BoundingBoxFormat.XYXY, - canvas_size=img_size, - dtype=torch.float32, - ) - fake_points = Points([[2, 2]], canvas_size=img_size, dtype=torch.float32) - fake_masks = tv_tensors.Mask(torch.rand(img_size)) - fake_labels = {"bboxes": torch.as_tensor([1], dtype=torch.int64)} - fake_polygons = [None] - # define data entity - single_data_entity = VisualPromptingDataEntity( - image=fake_image, - img_info=fake_image_info, - masks=fake_masks, - labels=fake_labels, - polygons=fake_polygons, - bboxes=fake_bboxes, - points=fake_points, - ) - batch_data_entity = VisualPromptingBatchDataEntity( - batch_size=1, - images=[fake_image], - imgs_info=[fake_image_info], - masks=[fake_masks], - labels=[fake_labels], - polygons=[fake_polygons], - bboxes=[fake_bboxes], - points=[fake_points], - ) - batch_pred_data_entity = VisualPromptingBatchPredEntity( - batch_size=1, - images=[fake_image], - imgs_info=[fake_image_info], - masks=[fake_masks], - labels=[fake_labels], - polygons=[fake_polygons], - bboxes=[fake_bboxes], - points=[fake_points], - scores=[], - ) - - return single_data_entity, batch_data_entity, batch_pred_data_entity - - -@pytest.fixture(scope="session") -def fxt_zero_shot_vpm_data_entity() -> ( - tuple[ - ZeroShotVisualPromptingDataEntity, - ZeroShotVisualPromptingBatchDataEntity, - ZeroShotVisualPromptingBatchPredEntity, - ] -): - img_size = (1024, 1024) - fake_image = tv_tensors.Image(torch.rand(img_size)) - fake_image_info = ImageInfo(img_idx=0, img_shape=img_size, ori_shape=img_size) - fake_bboxes = tv_tensors.BoundingBoxes( - [[0, 0, 1, 1]], - format=tv_tensors.BoundingBoxFormat.XYXY, - 
canvas_size=img_size, - dtype=torch.float32, - ) - fake_points = Points([[2, 2]], canvas_size=img_size, dtype=torch.float32) - fake_masks = tv_tensors.Mask(torch.rand(img_size)) - fake_labels = torch.as_tensor([1], dtype=torch.int64) - fake_polygons = [None] - # define data entity - single_data_entity = ZeroShotVisualPromptingDataEntity( - image=fake_image, - img_info=fake_image_info, - masks=fake_masks, - labels=fake_labels, - polygons=fake_polygons, - prompts=[fake_bboxes, fake_points], - ) - batch_data_entity = ZeroShotVisualPromptingBatchDataEntity( - batch_size=1, - images=[fake_image], - imgs_info=[fake_image_info], - masks=[fake_masks], - labels=[fake_labels], - polygons=[fake_polygons], - prompts=[[fake_bboxes, fake_points]], - ) - batch_pred_data_entity = ZeroShotVisualPromptingBatchPredEntity( - batch_size=1, - images=[fake_image], - imgs_info=[fake_image_info], - masks=[fake_masks], - labels=[fake_labels], - polygons=[fake_polygons], - prompts=[[fake_bboxes, fake_points]], - scores=[], - ) - - return single_data_entity, batch_data_entity, batch_pred_data_entity +__all__ = ["fxt_vpm_data_entity", "fxt_zero_shot_vpm_data_entity"] diff --git a/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py b/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py index b9fe3a5e58e..4a7f6ac1bc1 100644 --- a/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py +++ b/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py @@ -579,9 +579,9 @@ def test_customize_outputs(self, model, fxt_zero_shot_vpm_data_entity) -> None: def test_gather_prompts_with_labels(self, model) -> None: """Test _gather_prompts_with_labels.""" prompts = [[torch.tensor(0), torch.tensor(1), torch.tensor(2), torch.tensor(2), torch.tensor(4)]] - labels = [torch.tensor([0, 1, 2, 2, 4])] + labels = [{"prompts": torch.tensor([0, 1, 2, 2, 4])}] - results = model._gather_prompts_with_labels(prompts, labels) + results = model._gather_prompts_with_labels(labels, prompts) assert results[0][0][0] == prompts[0][0] assert results[0][1][0] == prompts[0][1] diff --git a/tests/unit/core/conftest.py b/tests/unit/core/conftest.py index d571d9630bf..40dc9b1066a 100644 --- a/tests/unit/core/conftest.py +++ b/tests/unit/core/conftest.py @@ -5,7 +5,7 @@ import numpy as np import pytest import torch -from datumaro import Label +from datumaro import Label, Polygon from datumaro.components.annotation import AnnotationType, LabelCategories from datumaro.components.dataset import Dataset, DatasetItem from datumaro.components.media import Image @@ -151,8 +151,12 @@ def fxt_zero_shot_vpm_data_entity() -> ( ) fake_points = Points([[2, 2]], canvas_size=img_size, dtype=torch.float32) fake_masks = tv_tensors.Mask(torch.ones(1, *img_size)) - fake_labels = torch.as_tensor([1, 2], dtype=torch.int64) - fake_polygons = [None] + fake_labels = { + "prompts": torch.as_tensor([1, 2], dtype=torch.int64), + "masks": torch.as_tensor([1], dtype=torch.int64), + "polygons": torch.as_tensor([2], dtype=torch.int64), + } + fake_polygons = [Polygon(points=[1, 1, 1, 2, 2, 2, 2, 1])] fake_scores = torch.tensor([[1.0]]) # define data entity single_data_entity = ZeroShotVisualPromptingDataEntity( @@ -177,7 +181,7 @@ def fxt_zero_shot_vpm_data_entity() -> ( images=[fake_image], imgs_info=[fake_image_info], masks=[fake_masks], - labels=[fake_labels], + labels=[fake_labels["prompts"]], polygons=[fake_polygons], prompts=[[fake_bboxes, fake_points]], scores=[fake_scores], diff --git 
a/tests/unit/core/data/dataset/test_visual_prompting.py b/tests/unit/core/data/dataset/test_visual_prompting.py index 6ff53b96014..5b12e5487a5 100644 --- a/tests/unit/core/data/dataset/test_visual_prompting.py +++ b/tests/unit/core/data/dataset/test_visual_prompting.py @@ -9,7 +9,6 @@ from datumaro import Dataset as DmDataset from otx.core.data.dataset.visual_prompting import OTXVisualPromptingDataset, OTXZeroShotVisualPromptingDataset from otx.core.data.entity.base import ImageInfo, Points -from torch import Tensor from torchvision.transforms.v2 import Identity, Transform from torchvision.tv_tensors import BoundingBoxes, Image, Mask @@ -103,7 +102,7 @@ def test_get_item_impl_subset( assert hasattr(entity, "masks") assert isinstance(entity.masks, Mask) assert hasattr(entity, "labels") - assert isinstance(entity.labels, Tensor) + assert isinstance(entity.labels, dict) assert hasattr(entity, "polygons") assert isinstance(entity.polygons, list) assert hasattr(entity, "prompts") diff --git a/tests/unit/core/model/test_visual_prompting.py b/tests/unit/core/model/test_visual_prompting.py index 93f169cfdd8..27fe347a896 100644 --- a/tests/unit/core/model/test_visual_prompting.py +++ b/tests/unit/core/model/test_visual_prompting.py @@ -120,7 +120,9 @@ def test_inference_step_for_zero_shot_with_more_target( ) -> None: """Test _inference_step_for_zero_shot with more target.""" otx_visual_prompting_model.configure_metric() - mocker.patch.object(otx_visual_prompting_model, "forward", return_value=fxt_zero_shot_vpm_data_entity[2]) + pred_entity = deepcopy(fxt_zero_shot_vpm_data_entity[2]) + pred_entity.labels = [label["prompts"] for label in pred_entity.labels] + mocker.patch.object(otx_visual_prompting_model, "forward", return_value=pred_entity) mocker_updates = {} for k, v in otx_visual_prompting_model.metric.items(): mocker_updates[k] = mocker.patch.object(v, "update") @@ -460,9 +462,9 @@ def test_learn(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_ reset_feat=False, ) - assert reference_info["reference_feats"].shape == torch.Size((2, 1, 256)) + assert reference_info["reference_feats"].shape == torch.Size((3, 1, 256)) assert 1 in reference_info["used_indices"] - assert ref_masks[0].shape == torch.Size((2, 1024, 1024)) + assert ref_masks[0].shape == torch.Size((3, 1024, 1024)) def test_infer(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_vpm_data_entity) -> None: """Test infer.""" From 869618cefbfbee9a40e16bcc824a0bf505611e5a Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Tue, 30 Jul 2024 10:36:31 +0900 Subject: [PATCH 10/14] Revert `BoundingBoxFormat` --- src/otx/core/data/dataset/visual_prompting.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/otx/core/data/dataset/visual_prompting.py b/src/otx/core/data/dataset/visual_prompting.py index fd368b11dab..2512ae7eca8 100644 --- a/src/otx/core/data/dataset/visual_prompting.py +++ b/src/otx/core/data/dataset/visual_prompting.py @@ -19,6 +19,7 @@ from torchvision import tv_tensors from torchvision.transforms.v2.functional import convert_bounding_box_format, to_image from torchvision.tv_tensors import BoundingBoxes as tvBoundingBoxes +from torchvision.tv_tensors import BoundingBoxFormat as tvBoundingBoxFormat from torchvision.tv_tensors import Mask as tvMask from otx.core.data.entity.base import ImageInfo, Points @@ -96,11 +97,11 @@ def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None: # get bbox bbox = tvBoundingBoxes( annotation.get_bbox(), - format="xywh", + 
format=tvBoundingBoxFormat.XYWH, canvas_size=img_shape, dtype=torch.float32, ) - bbox = convert_bounding_box_format(bbox, new_format="xyxy") + bbox = convert_bounding_box_format(bbox, new_format=tvBoundingBoxFormat.XYXY) gt_bboxes.append(bbox) gt_labels["bboxes"].append(annotation.label) gt_masks["bboxes"].append(mask) @@ -231,11 +232,11 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None # get bbox bbox = tvBoundingBoxes( annotation.get_bbox(), - format="xywh", + format=tvBoundingBoxFormat.XYWH, canvas_size=img_shape, dtype=torch.float32, ) - bbox = convert_bounding_box_format(bbox, new_format="xyxy") + bbox = convert_bounding_box_format(bbox, new_format=tvBoundingBoxFormat.XYXY) gt_prompts.append(bbox) else: # get center point From 5afa7927e39ddb1f969d3f56d830f528a58eb164 Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Tue, 30 Jul 2024 10:57:26 +0900 Subject: [PATCH 11/14] Fix unit test --- .../unit/core/model/test_visual_prompting.py | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/tests/unit/core/model/test_visual_prompting.py b/tests/unit/core/model/test_visual_prompting.py index 27fe347a896..30e9c8a04bc 100644 --- a/tests/unit/core/model/test_visual_prompting.py +++ b/tests/unit/core/model/test_visual_prompting.py @@ -65,17 +65,15 @@ def test_inference_step(mocker, otx_visual_prompting_model, fxt_vpm_data_entity) def test_inference_step_for_zero_shot(mocker, otx_visual_prompting_model, fxt_zero_shot_vpm_data_entity) -> None: """Test _inference_step_for_zero_shot.""" + entity = deepcopy(fxt_zero_shot_vpm_data_entity[1]) + pred_entity = deepcopy(fxt_zero_shot_vpm_data_entity[2]) otx_visual_prompting_model.configure_metric() - mocker.patch.object(otx_visual_prompting_model, "forward", return_value=fxt_zero_shot_vpm_data_entity[2]) + mocker.patch.object(otx_visual_prompting_model, "forward", return_value=pred_entity) mocker_updates = {} for k, v in otx_visual_prompting_model.metric.items(): mocker_updates[k] = mocker.patch.object(v, "update") - _inference_step_for_zero_shot( - otx_visual_prompting_model, - otx_visual_prompting_model.metric, - fxt_zero_shot_vpm_data_entity[1], - ) + _inference_step_for_zero_shot(otx_visual_prompting_model, otx_visual_prompting_model.metric, entity) for v in mocker_updates.values(): v.assert_called_once() @@ -88,8 +86,10 @@ def test_inference_step_for_zero_shot_with_more_preds( ) -> None: """Test _inference_step_for_zero_shot with more preds.""" otx_visual_prompting_model.configure_metric() + entity = deepcopy(fxt_zero_shot_vpm_data_entity[1]) + pred_entity = deepcopy(fxt_zero_shot_vpm_data_entity[2]) preds = {} - for k, v in fxt_zero_shot_vpm_data_entity[2].__dict__.items(): + for k, v in pred_entity.__dict__.items(): if k in ["batch_size", "polygons"]: preds[k] = v else: @@ -103,11 +103,7 @@ def test_inference_step_for_zero_shot_with_more_preds( for k, v in otx_visual_prompting_model.metric.items(): mocker_updates[k] = mocker.patch.object(v, "update") - _inference_step_for_zero_shot( - otx_visual_prompting_model, - otx_visual_prompting_model.metric, - fxt_zero_shot_vpm_data_entity[1], - ) + _inference_step_for_zero_shot(otx_visual_prompting_model, otx_visual_prompting_model.metric, entity) for v in mocker_updates.values(): v.assert_called_once() @@ -120,14 +116,14 @@ def test_inference_step_for_zero_shot_with_more_target( ) -> None: """Test _inference_step_for_zero_shot with more target.""" otx_visual_prompting_model.configure_metric() + entity = 
deepcopy(fxt_zero_shot_vpm_data_entity[1]) pred_entity = deepcopy(fxt_zero_shot_vpm_data_entity[2]) - pred_entity.labels = [label["prompts"] for label in pred_entity.labels] mocker.patch.object(otx_visual_prompting_model, "forward", return_value=pred_entity) mocker_updates = {} for k, v in otx_visual_prompting_model.metric.items(): mocker_updates[k] = mocker.patch.object(v, "update") target = {} - for k, v in fxt_zero_shot_vpm_data_entity[1].__dict__.items(): + for k, v in entity.__dict__.items(): if k in ["batch_size"]: target[k] = v else: @@ -446,6 +442,7 @@ def test_forward( def test_learn(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_vpm_data_entity) -> None: """Test learn.""" + entity = deepcopy(fxt_zero_shot_vpm_data_entity[1]) ov_zero_shot_visual_prompting_model.model._reference_features = np.zeros((0, 1, 256), dtype=np.float32) ov_zero_shot_visual_prompting_model.model._used_indices = np.array([], dtype=np.int64) ov_zero_shot_visual_prompting_model.model.decoder.mask_threshold = 0.0 @@ -458,7 +455,7 @@ def test_learn(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_ return_value=np.random.rand(1, 256), ) reference_info, ref_masks = ov_zero_shot_visual_prompting_model.learn( - inputs=fxt_zero_shot_vpm_data_entity[1], + inputs=entity, reset_feat=False, ) @@ -468,6 +465,7 @@ def test_learn(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_ def test_infer(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_vpm_data_entity) -> None: """Test infer.""" + entity = deepcopy(fxt_zero_shot_vpm_data_entity[1]) ov_zero_shot_visual_prompting_model.model.decoder.mask_threshold = 0.0 ov_zero_shot_visual_prompting_model.model.decoder.output_blob_name = "upscaled_masks" @@ -488,7 +486,7 @@ def test_infer(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_ used_indices = np.array([1]) results = ov_zero_shot_visual_prompting_model.infer( - inputs=fxt_zero_shot_vpm_data_entity[1], + inputs=entity, reference_feats=reference_feats, used_indices=used_indices, ) @@ -502,11 +500,12 @@ def test_customize_outputs_training( ov_zero_shot_visual_prompting_model, fxt_zero_shot_vpm_data_entity, ) -> None: + entity = deepcopy(fxt_zero_shot_vpm_data_entity[1]) ov_zero_shot_visual_prompting_model.training = True outputs = ({"foo": np.array(1), "bar": np.array(2)}, [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]) - result = ov_zero_shot_visual_prompting_model._customize_outputs(outputs, fxt_zero_shot_vpm_data_entity[1]) + result = ov_zero_shot_visual_prompting_model._customize_outputs(outputs, entity) assert result == outputs From ab08d6845bdca69cecb90482049e6e32d07aa4fe Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Tue, 30 Jul 2024 13:55:58 +0900 Subject: [PATCH 12/14] Add `ZeroShotVisualPromptingLabel` --- .../zero_shot_segment_anything.py | 27 +++++++------------ src/otx/core/data/dataset/visual_prompting.py | 7 +++-- src/otx/core/data/entity/visual_prompting.py | 24 ++++++++++++----- src/otx/core/model/visual_prompting.py | 12 +++++---- .../test_zero_shot_segment_anything.py | 18 ++++++------- tests/unit/core/conftest.py | 13 ++++----- .../data/dataset/test_visual_prompting.py | 5 ++-- 7 files changed, 58 insertions(+), 48 deletions(-) diff --git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py index a5aed3cc852..6c4bee6151f 100644 --- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py +++ 
b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py @@ -741,16 +741,7 @@ def _customize_inputs( # type: ignore[override] } if self.training: # learn - forward_inputs.update( - { - "processed_prompts": self._gather_prompts_with_labels( - inputs.labels, - inputs.prompts, - inputs.polygons, - inputs.masks, - ), - }, - ) + forward_inputs.update({"processed_prompts": self._gather_prompts_with_labels(inputs)}) return forward_inputs @@ -814,19 +805,18 @@ def _customize_outputs( # type: ignore[override] def _gather_prompts_with_labels( self, - labels: list[dict[Literal["prompts", "polygons", "masks"], Tensor]], - prompts: list[list[BoundingBoxes | Points]] | None = None, - polygons: list[list[dmPolygon]] | None = None, - masks: list[Mask] | None = None, + inputs: ZeroShotVisualPromptingBatchDataEntity, ) -> list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]]: """Gather prompts according to labels.""" total_processed_prompts: list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]] = [] - for batch, batch_labels in enumerate(labels): + for batch, batch_labels in enumerate(inputs.labels): processed_prompts = defaultdict(list) - for prompt_type, prompt_labels in batch_labels.items(): - _prompts = locals()[prompt_type] - if _prompts is None: + for prompt_type in ["prompts", "polygons", "masks"]: + _prompts = getattr(inputs, prompt_type, None) + prompt_labels = getattr(batch_labels, prompt_type, None) + if _prompts is None or prompt_labels is None: continue + for idx, _label in enumerate(prompt_labels): if prompt_type in ("prompts", "polygons"): processed_prompts[int(_label)].append(_prompts[batch][idx]) @@ -911,6 +901,7 @@ def transforms(self, entity: ZeroShotVisualPromptingBatchDataEntity) -> ZeroShot ], masks=entity.masks, polygons=entity.polygons, + labels=entity.labels, ) def initialize_reference_info(self) -> None: diff --git a/src/otx/core/data/dataset/visual_prompting.py b/src/otx/core/data/dataset/visual_prompting.py index 2512ae7eca8..421acc1b02d 100644 --- a/src/otx/core/data/dataset/visual_prompting.py +++ b/src/otx/core/data/dataset/visual_prompting.py @@ -28,6 +28,7 @@ VisualPromptingDataEntity, ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingDataEntity, + ZeroShotVisualPromptingLabel, ) from otx.core.types.label import NullLabelInfo from otx.core.utils.mask_util import polygon_to_bitmap @@ -260,7 +261,9 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None if len(gt_prompts) == 0: return None - labels = {prompt_type: torch.as_tensor(values, dtype=torch.int64) for prompt_type, values in gt_labels.items()} + labels = { + str(prompt_type): torch.as_tensor(values, dtype=torch.int64) for prompt_type, values in gt_labels.items() + } masks = tvMask(torch.stack(gt_masks, dim=0), dtype=torch.uint8) return ZeroShotVisualPromptingDataEntity( @@ -271,7 +274,7 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None ori_shape=img_shape, ), masks=masks, - labels=labels, + labels=ZeroShotVisualPromptingLabel(**labels), polygons=gt_polygons, prompts=gt_prompts, ) diff --git a/src/otx/core/data/entity/visual_prompting.py b/src/otx/core/data/entity/visual_prompting.py index 0a6710baab0..8f5eed33b57 100644 --- a/src/otx/core/data/entity/visual_prompting.py +++ b/src/otx/core/data/entity/visual_prompting.py @@ -6,7 +6,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING from torchvision import tv_tensors @@ 
-127,6 +127,15 @@ class VisualPromptingBatchPredEntity(OTXBatchPredEntity, VisualPromptingBatchDat """Data entity to represent model output predictions for visual prompting task.""" +@dataclass +class ZeroShotVisualPromptingLabel: + """Label dataclass for zero-shot visual prompting data entity.""" + + prompts: LongTensor | None = None + polygons: LongTensor | None = None + masks: LongTensor | None = None + + @register_pytree_node @dataclass class ZeroShotVisualPromptingDataEntity(OTXDataEntity): @@ -134,7 +143,7 @@ class ZeroShotVisualPromptingDataEntity(OTXDataEntity): Attributes: masks (tvMask): The masks of the instances. - labels (dict[Literal["prompts", "polygons", "masks"], list[LongTensor]]): The labels of the instances + labels (ZeroShotVisualPromptingLabel): The labels of the instances for each prompt. polygons (list[dmPolygon]): The polygons of the instances. prompts (list[tvBoundingBoxes | Points]): The prompts of the instances. @@ -146,7 +155,7 @@ def task(self) -> OTXTaskType: return OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING masks: tvMask - labels: dict[Literal["prompts", "polygons", "masks"], LongTensor] + labels: ZeroShotVisualPromptingLabel polygons: list[dmPolygon] prompts: list[tvBoundingBoxes | Points] @@ -157,13 +166,13 @@ class ZeroShotVisualPromptingBatchDataEntity(OTXBatchDataEntity[ZeroShotVisualPr Attributes: masks (list[tvMask]): List of masks. - labels (list[dict[Literal["prompts", "polygons", "masks"], LongTensor]]): List of labels. + labels (list[ZeroShotVisualPromptingLabel]): List of labels. polygons (list[list[dmPolygon]]): List of polygons. prompts (list[list[tvBoundingBoxes | Points]]): List of prompts. """ masks: list[tvMask] - labels: list[dict[Literal["prompts", "polygons", "masks"], LongTensor]] + labels: list[ZeroShotVisualPromptingLabel] polygons: list[list[dmPolygon]] prompts: list[list[tvBoundingBoxes | Points]] @@ -208,7 +217,10 @@ def pin_memory(self) -> ZeroShotVisualPromptingBatchDataEntity: for prompts in self.prompts ], masks=[tv_tensors.wrap(mask.pin_memory(), like=mask) for mask in self.masks], - labels=[{k: v.pin_memory() for k, v in label.items()} for label in self.labels], + labels=[ + ZeroShotVisualPromptingLabel(**{k: v.pin_memory() for k, v in label.__dict__.items()}) + for label in self.labels + ], ) ) diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py index 85a75df0417..a27bd2a179c 100644 --- a/src/otx/core/model/visual_prompting.py +++ b/src/otx/core/model/visual_prompting.py @@ -132,7 +132,9 @@ def _inference_step_for_zero_shot( raise TypeError(preds) # filter labels using corresponding ground truth - inputs.labels = [label["masks"] if len(inputs.masks) > 0 else label["polygons"] for label in inputs.labels] + inputs.labels = [ + label.masks if inputs.masks and label.masks is not None else label.polygons for label in inputs.labels + ] converted_entities: dict[str, list[dict[str, Tensor]]] = _convert_pred_entity_to_compute_metric(preds, inputs) # type: ignore[assignment] @@ -929,14 +931,14 @@ def _customize_inputs( # type: ignore[override] _bboxes: list[Prompt] = [] _points: list[Prompt] = [] _polygons: list[Prompt] = [] - for prompt, label in zip(prompts, labels["prompts"]): # type: ignore[arg-type] + for prompt, label in zip(prompts, labels.prompts): # type: ignore[arg-type] if isinstance(prompt, tv_tensors.BoundingBoxes): _bboxes.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy())) elif isinstance(prompt, Points): _points.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy())) - if 
diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py
index 85a75df0417..a27bd2a179c 100644
--- a/src/otx/core/model/visual_prompting.py
+++ b/src/otx/core/model/visual_prompting.py
@@ -132,7 +132,9 @@ def _inference_step_for_zero_shot(
         raise TypeError(preds)
 
     # filter labels using corresponding ground truth
-    inputs.labels = [label["masks"] if len(inputs.masks) > 0 else label["polygons"] for label in inputs.labels]
+    inputs.labels = [
+        label.masks if inputs.masks and label.masks is not None else label.polygons for label in inputs.labels
+    ]
 
     converted_entities: dict[str, list[dict[str, Tensor]]] = _convert_pred_entity_to_compute_metric(preds, inputs)  # type: ignore[assignment]
@@ -929,14 +931,14 @@ def _customize_inputs(  # type: ignore[override]
             _bboxes: list[Prompt] = []
             _points: list[Prompt] = []
             _polygons: list[Prompt] = []
-            for prompt, label in zip(prompts, labels["prompts"]):  # type: ignore[arg-type]
+            for prompt, label in zip(prompts, labels.prompts):  # type: ignore[arg-type]
                 if isinstance(prompt, tv_tensors.BoundingBoxes):
                     _bboxes.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy()))
                 elif isinstance(prompt, Points):
                     _points.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy()))
-            if polygons:
-                for polygon, label in zip(polygons, labels["polygons"]):
+            if polygons and labels.polygons is not None:
+                for polygon, label in zip(polygons, labels.polygons):
                     _polygons.append(Prompt(np.array(polygon.points, dtype=np.int32), label.cpu().numpy()))
 
             # TODO (sungchul, sovrasov): support mask?
@@ -1059,7 +1061,7 @@ def transform_fn(
             _labels: dict[str, list[int]] = defaultdict(list)
 
             # use only the first prompt
-            for prompt, label in zip(data_batch.prompts[0], data_batch.labels[0]["prompts"]):  # type: ignore[arg-type]
+            for prompt, label in zip(data_batch.prompts[0], data_batch.labels[0].prompts):  # type: ignore[arg-type]
                 if isinstance(prompt, tv_tensors.BoundingBoxes):
                     bboxes.append(prompt.cpu().numpy())
                     _labels["bboxes"].append(label.cpu().numpy())
diff --git a/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py b/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py
index 4a7f6ac1bc1..90dbddd65c2 100644
--- a/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py
+++ b/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py
@@ -1,8 +1,9 @@
-# Copyright (C) 2023 Intel Corporation
+# Copyright (C) 2023-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
 from __future__ import annotations
 
+from copy import deepcopy
 from typing import Any, Callable
 
 import pytest
@@ -576,17 +577,16 @@ def test_customize_outputs(self, model, fxt_zero_shot_vpm_data_entity) -> None:
         assert result.labels[0] == label
         assert torch.all(result.prompts[0].data == outputs[0][1][label][0][:2].unsqueeze(0))
 
-    def test_gather_prompts_with_labels(self, model) -> None:
+    def test_gather_prompts_with_labels(self, model, fxt_zero_shot_vpm_data_entity) -> None:
         """Test _gather_prompts_with_labels."""
-        prompts = [[torch.tensor(0), torch.tensor(1), torch.tensor(2), torch.tensor(2), torch.tensor(4)]]
-        labels = [{"prompts": torch.tensor([0, 1, 2, 2, 4])}]
+        entity = deepcopy(fxt_zero_shot_vpm_data_entity[1])
 
-        results = model._gather_prompts_with_labels(labels, prompts)
+        results = model._gather_prompts_with_labels(entity)
 
-        assert results[0][0][0] == prompts[0][0]
-        assert results[0][1][0] == prompts[0][1]
-        assert results[0][2] == prompts[0][2:4]
-        assert results[0][4][0] == prompts[0][4]
+        assert torch.all(results[0][1][0] == entity.prompts[0][0])
+        assert torch.all(results[0][1][1] == entity.masks[0])
+        assert torch.all(results[0][2][0] == entity.prompts[0][1])
+        assert results[0][2][1] == entity.polygons[0][0]
 
     @pytest.mark.parametrize(
         ("image", "expected"),
diff --git a/tests/unit/core/conftest.py b/tests/unit/core/conftest.py
index 40dc9b1066a..0dacf5fd15a 100644
--- a/tests/unit/core/conftest.py
+++ b/tests/unit/core/conftest.py
@@ -18,6 +18,7 @@
     ZeroShotVisualPromptingBatchDataEntity,
     ZeroShotVisualPromptingBatchPredEntity,
     ZeroShotVisualPromptingDataEntity,
+    ZeroShotVisualPromptingLabel,
 )
 from torchvision import tv_tensors
 
@@ -151,11 +152,11 @@ def fxt_zero_shot_vpm_data_entity() -> (
     )
     fake_points = Points([[2, 2]], canvas_size=img_size, dtype=torch.float32)
     fake_masks = tv_tensors.Mask(torch.ones(1, *img_size))
-    fake_labels = {
-        "prompts": torch.as_tensor([1, 2], dtype=torch.int64),
-        "masks": torch.as_tensor([1], dtype=torch.int64),
-        "polygons": torch.as_tensor([2], dtype=torch.int64),
-    }
+    fake_labels = ZeroShotVisualPromptingLabel(
+        prompts=torch.as_tensor([1, 2], dtype=torch.int64),
+        masks=torch.as_tensor([1], dtype=torch.int64),
+        polygons=torch.as_tensor([2], dtype=torch.int64),
+    )
     fake_polygons = [Polygon(points=[1, 1, 1, 2, 2, 2, 2, 1])]
     fake_scores = torch.tensor([[1.0]])
 
     # define data entity
@@ -181,7 +182,7 @@ def fxt_zero_shot_vpm_data_entity() -> (
         images=[fake_image],
         imgs_info=[fake_image_info],
         masks=[fake_masks],
-        labels=[fake_labels["prompts"]],
+        labels=[fake_labels.prompts],
         polygons=[fake_polygons],
         prompts=[[fake_bboxes, fake_points]],
         scores=[fake_scores],
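A small sketch of the ground-truth label selection that _inference_step_for_zero_shot performs above. With the dataclass, a missing prompt type is None rather than a KeyError, so the fallback to polygon labels can be written explicitly; the tensors here are toy values, not fixture data.

# Sketch of the masks-then-polygons fallback, assuming toy label tensors.
import torch

from otx.core.data.entity.visual_prompting import ZeroShotVisualPromptingLabel

labels = [
    ZeroShotVisualPromptingLabel(masks=torch.tensor([1]), polygons=torch.tensor([2])),
    ZeroShotVisualPromptingLabel(polygons=torch.tensor([2])),  # no mask labels here
]
masks = [torch.ones(1, 8, 8)]  # non-empty, so mask labels win when present

selected = [
    label.masks if masks and label.masks is not None else label.polygons
    for label in labels
]
# -> [tensor([1]), tensor([2])]: the second entry falls back to polygon labels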
diff --git a/tests/unit/core/data/dataset/test_visual_prompting.py b/tests/unit/core/data/dataset/test_visual_prompting.py
index 5b12e5487a5..a651cf1589b 100644
--- a/tests/unit/core/data/dataset/test_visual_prompting.py
+++ b/tests/unit/core/data/dataset/test_visual_prompting.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2023 Intel Corporation
+# Copyright (C) 2023-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
 """Unit tests of visual prompting datasets."""
@@ -9,6 +9,7 @@
 from datumaro import Dataset as DmDataset
 from otx.core.data.dataset.visual_prompting import OTXVisualPromptingDataset, OTXZeroShotVisualPromptingDataset
 from otx.core.data.entity.base import ImageInfo, Points
+from otx.core.data.entity.visual_prompting import ZeroShotVisualPromptingLabel
 from torchvision.transforms.v2 import Identity, Transform
 from torchvision.tv_tensors import BoundingBoxes, Image, Mask
@@ -102,7 +103,7 @@ def test_get_item_impl_subset(
         assert hasattr(entity, "masks")
         assert isinstance(entity.masks, Mask)
         assert hasattr(entity, "labels")
-        assert isinstance(entity.labels, dict)
+        assert isinstance(entity.labels, ZeroShotVisualPromptingLabel)
         assert hasattr(entity, "polygons")
         assert isinstance(entity.polygons, list)
         assert hasattr(entity, "prompts")

From 3f1333fdee2f85e1fe26267f32fbbab18a368e4a Mon Sep 17 00:00:00 2001
From: "Kim, Sungchul" 
Date: Tue, 30 Jul 2024 13:59:41 +0900
Subject: [PATCH 13/14] Update condition

---
 src/otx/core/data/dataset/visual_prompting.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/otx/core/data/dataset/visual_prompting.py b/src/otx/core/data/dataset/visual_prompting.py
index 421acc1b02d..74389d4fbf2 100644
--- a/src/otx/core/data/dataset/visual_prompting.py
+++ b/src/otx/core/data/dataset/visual_prompting.py
@@ -258,7 +258,7 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None
             elif isinstance(annotation, (dmBbox, dmMask, dmPoints)):
                 pass
 
-        if len(gt_prompts) == 0:
+        if not gt_prompts:
             return None
 
         labels = {

From 6da4a6aff533dd9932a8f6d1e0d05daeb69ea9d7 Mon Sep 17 00:00:00 2001
From: "Kim, Sungchul" 
Date: Tue, 30 Jul 2024 14:11:00 +0900
Subject: [PATCH 14/14] Update CHANGELOG

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bb8fe6a455a..f27500069e0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,8 @@ All notable changes to this project will be documented in this file.
   ()
 - Enable to use input_size at transforms in recipe
   ()
+- Enable to use polygon and bitmap mask as prompt inputs for zero-shot learning
+  ()
 
 ### Bug fixes
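The one-line change in PATCH 13 relies on standard Python truthiness: for any list, "not gt_prompts" is equivalent to "len(gt_prompts) == 0", which is what makes the swap safe and idiomatic. A tiny self-check:

# Truthiness equivalence behind the PATCH 13 condition change.
gt_prompts: list = []
assert (not gt_prompts) == (len(gt_prompts) == 0)  # both True when empty
gt_prompts.append(object())
assert (not gt_prompts) == (len(gt_prompts) == 0)  # both False when non-empty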