diff --git a/CHANGELOG.md b/CHANGELOG.md index bb8fe6a455a..f27500069e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ All notable changes to this project will be documented in this file. () - Enable to use input_size at transforms in recipe () +- Enable to use polygon and bitmap mask as prompt inputs for zero-shot learning + () ### Bug fixes diff --git a/src/otx/algo/visual_prompting/segment_anything.py b/src/otx/algo/visual_prompting/segment_anything.py index be318a1d2ea..baac47fe114 100644 --- a/src/otx/algo/visual_prompting/segment_anything.py +++ b/src/otx/algo/visual_prompting/segment_anything.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Segment Anything model for the OTX visual prompting.""" diff --git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py index dd650486e30..6c4bee6151f 100644 --- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py +++ b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Segment Anything model for the OTX zero-shot visual prompting.""" @@ -19,7 +19,7 @@ from torch import Tensor, nn from torch.nn import functional as F # noqa: N812 from torchvision import tv_tensors -from torchvision.tv_tensors import BoundingBoxes, Image, Mask, TVTensor +from torchvision.tv_tensors import BoundingBoxes, Image, Mask from otx.algo.visual_prompting.segment_anything import DEFAULT_CONFIG_SEGMENT_ANYTHING, SegmentAnything from otx.core.data.entity.base import OTXBatchLossEntity, Points @@ -32,6 +32,7 @@ from otx.core.model.visual_prompting import OTXZeroShotVisualPromptingModel from otx.core.schedulers import LRSchedulerListCallable from otx.core.types.label import LabelInfoTypes, NullLabelInfo +from otx.core.utils.mask_util import polygon_to_bitmap if TYPE_CHECKING: import numpy as np @@ -230,7 +231,7 @@ def expand_reference_info(self, reference_feats: Tensor, new_largest_label: int) def learn( self, images: list[Image], - processed_prompts: list[dict[int, list[TVTensor]]], + processed_prompts: list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]], reference_feats: Tensor, used_indices: Tensor, ori_shapes: list[Tensor], @@ -244,7 +245,7 @@ def learn( Args: images (list[Image]): List of given images for reference features. - processed_prompts (dict[int, list[TVTensor]]): The class-wise prompts + processed_prompts (dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]): The class-wise prompts processed at OTXZeroShotSegmentAnything._gather_prompts_with_labels. reference_feats (Tensor): Reference features for target prediction. used_indices (Tensor): To check which indices of reference features are validate. @@ -252,7 +253,7 @@ def learn( is_cascade (bool): Whether use cascade inference. Defaults to False. """ # initialize tensors to contain reference features and prompts - largest_label = max(sum([[int(p) for p in prompt] for prompt in processed_prompts], [])) + largest_label = max([max(prompt.keys()) for prompt in processed_prompts]) reference_feats = self.expand_reference_info(reference_feats, largest_label) new_used_indices: list[Tensor] = [] # TODO (sungchul): consider how to handle multiple reference features, currently replace it @@ -270,7 +271,9 @@ def learn( for input_prompt in input_prompts: if isinstance(input_prompt, Mask): # directly use annotation information as a mask - ref_mask[input_prompt == 1] += 1 # TODO (sungchul): check if the mask is bool or int + ref_mask[input_prompt] += 1 + elif isinstance(input_prompt, dmPolygon): + ref_mask[torch.as_tensor(polygon_to_bitmap([input_prompt], *ori_shape)[0])] += 1 else: if isinstance(input_prompt, BoundingBoxes): point_coords = input_prompt.reshape(-1, 2, 2) @@ -278,12 +281,6 @@ def learn( elif isinstance(input_prompt, Points): point_coords = input_prompt.reshape(-1, 1, 2) point_labels = torch.tensor([[1]], device=point_coords.device) - elif isinstance( - input_prompt, - dmPolygon, - ): # TODO (sungchul): add other polygon types - # TODO (sungchul): convert polygon to mask - continue else: log.info(f"Current input prompt ({input_prompt.__class__.__name__}) is not supported.") continue @@ -744,9 +741,7 @@ def _customize_inputs( # type: ignore[override] } if self.training: # learn - forward_inputs.update( - {"processed_prompts": self._gather_prompts_with_labels(inputs.prompts, inputs.labels)}, - ) + forward_inputs.update({"processed_prompts": self._gather_prompts_with_labels(inputs)}) return forward_inputs @@ -810,17 +805,28 @@ def _customize_outputs( # type: ignore[override] def _gather_prompts_with_labels( self, - prompts: list[list[TVTensor]], - labels: list[Tensor], - ) -> list[dict[int, list[TVTensor]]]: + inputs: ZeroShotVisualPromptingBatchDataEntity, + ) -> list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]]: """Gather prompts according to labels.""" - total_processed_prompts: list[dict[int, list[TVTensor]]] = [] - for prompt, label in zip(prompts, labels): + total_processed_prompts: list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]] = [] + for batch, batch_labels in enumerate(inputs.labels): processed_prompts = defaultdict(list) - for _prompt, _label in zip(prompt, label): # type: ignore[arg-type] - processed_prompts[int(_label)].append(_prompt) + for prompt_type in ["prompts", "polygons", "masks"]: + _prompts = getattr(inputs, prompt_type, None) + prompt_labels = getattr(batch_labels, prompt_type, None) + if _prompts is None or prompt_labels is None: + continue + + for idx, _label in enumerate(prompt_labels): + if prompt_type in ("prompts", "polygons"): + processed_prompts[int(_label)].append(_prompts[batch][idx]) + else: + # for mask + processed_prompts[int(_label)].append(Mask(_prompts[batch][idx])) + sorted_processed_prompts = dict(sorted(processed_prompts.items(), key=lambda x: x)) total_processed_prompts.append(sorted_processed_prompts) + return total_processed_prompts def apply_image(self, image: Image | np.ndarray, target_length: int = 1024) -> Image: @@ -893,6 +899,9 @@ def transforms(self, entity: ZeroShotVisualPromptingBatchDataEntity) -> ZeroShot self.apply_prompts(prompt, info.ori_shape, self.model.image_size) for prompt, info in zip(entity.prompts, entity.imgs_info) ], + masks=entity.masks, + polygons=entity.polygons, + labels=entity.labels, ) def initialize_reference_info(self) -> None: diff --git a/src/otx/core/data/dataset/visual_prompting.py b/src/otx/core/data/dataset/visual_prompting.py index be7c3a1f0c5..74389d4fbf2 100644 --- a/src/otx/core/data/dataset/visual_prompting.py +++ b/src/otx/core/data/dataset/visual_prompting.py @@ -1,16 +1,15 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -# + """Module for OTXVisualPromptingDataset.""" from __future__ import annotations from collections import defaultdict from functools import partial -from typing import Callable +from typing import Callable, Literal import torch -import torchvision.transforms.v2.functional as F # noqa: N812 from datumaro import Bbox as dmBbox from datumaro import Dataset as dmDataset from datumaro import Image as dmImage @@ -18,6 +17,10 @@ from datumaro import Points as dmPoints from datumaro import Polygon as dmPolygon from torchvision import tv_tensors +from torchvision.transforms.v2.functional import convert_bounding_box_format, to_image +from torchvision.tv_tensors import BoundingBoxes as tvBoundingBoxes +from torchvision.tv_tensors import BoundingBoxFormat as tvBoundingBoxFormat +from torchvision.tv_tensors import Mask as tvMask from otx.core.data.entity.base import ImageInfo, Points from otx.core.data.entity.visual_prompting import ( @@ -25,6 +28,7 @@ VisualPromptingDataEntity, ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingDataEntity, + ZeroShotVisualPromptingLabel, ) from otx.core.types.label import NullLabelInfo from otx.core.utils.mask_util import polygon_to_bitmap @@ -38,6 +42,14 @@ class OTXVisualPromptingDataset(OTXDataset[VisualPromptingDataEntity]): Args: dm_subset (dmDataset): The subset of the dataset. transforms (Transforms): Data transformations to be applied. + use_bbox (bool): Whether to use bounding box prompt. + If both use_bbox and use_point are False, use_bbox is set to True as default. + If both are True, divide the probability into both. + Defaults to True. + use_point (bool): Whether to use point prompt. + If both use_bbox and use_point are False, use_bbox is set to True as default. + If both are True, divide the probability into both. + Defaults to False. **kwargs: Additional keyword arguments passed to the base class. """ @@ -76,7 +88,7 @@ def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None: for annotation in item.annotations: if isinstance(annotation, dmPolygon): - mask = tv_tensors.Mask(polygon_to_bitmap([annotation], *img_shape)[0]) + mask = tvMask(polygon_to_bitmap([annotation], *img_shape)[0]) mask_points = torch.nonzero(mask) if len(mask_points[0]) == 0: # skip very small region @@ -84,16 +96,13 @@ def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None: if torch.rand(1) < self.prob: # get bbox - bbox = tv_tensors.BoundingBoxes( + bbox = tvBoundingBoxes( annotation.get_bbox(), - format=tv_tensors.BoundingBoxFormat.XYWH, + format=tvBoundingBoxFormat.XYWH, canvas_size=img_shape, dtype=torch.float32, ) - bbox = F._meta.convert_bounding_box_format( # noqa: SLF001 - bbox, - new_format=tv_tensors.BoundingBoxFormat.XYXY, - ) + bbox = convert_bounding_box_format(bbox, new_format=tvBoundingBoxFormat.XYXY) gt_bboxes.append(bbox) gt_labels["bboxes"].append(annotation.label) gt_masks["bboxes"].append(mask) @@ -127,7 +136,7 @@ def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None: bboxes = tv_tensors.wrap(torch.cat(gt_bboxes, dim=0), like=gt_bboxes[0]) if len(gt_bboxes) > 0 else None points = tv_tensors.wrap(torch.stack(gt_points, dim=0), like=gt_points[0]) if len(gt_points) > 0 else None labels = {prompt_type: torch.as_tensor(values, dtype=torch.int64) for prompt_type, values in gt_labels.items()} - masks = tv_tensors.Mask( + masks = tvMask( torch.stack(gt_masks.get("bboxes", []) + gt_masks.get("points", []), dim=0), dtype=torch.uint8, ) @@ -168,6 +177,14 @@ class OTXZeroShotVisualPromptingDataset(OTXDataset[ZeroShotVisualPromptingDataEn Args: dm_subset (dmDataset): The subset of the dataset. transforms (Transforms): Data transformations to be applied. + use_bbox (bool): Whether to use bounding box prompt. + If both use_bbox and use_point are False, use_bbox is set to True as default. + If both are True, divide the probability into both. + Defaults to True. + use_point (bool): Whether to use point prompt. + If both use_bbox and use_point are False, use_bbox is set to True as default. + If both are True, divide the probability into both. + Defaults to False. **kwargs: Additional keyword arguments passed to the base class. """ @@ -199,11 +216,14 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None img = item.media_as(dmImage) img_data, img_shape = self._get_img_data_and_shape(img) - gt_prompts, gt_masks, gt_polygons, gt_labels = [], [], [], [] + gt_prompts: list[tvBoundingBoxes | Points] = [] + gt_masks: list[tvMask] = [] + gt_polygons: list[dmPolygon] = [] + gt_labels: dict[Literal["prompts", "polygons", "masks"], list[int]] = defaultdict(list) for annotation in item.annotations: if isinstance(annotation, dmPolygon): # generate prompts from polygon - mask = tv_tensors.Mask(polygon_to_bitmap([annotation], *img_shape)[0]) + mask = tvMask(polygon_to_bitmap([annotation], *img_shape)[0]) mask_points = torch.nonzero(mask) if len(mask_points[0]) == 0: # skip very small region @@ -211,16 +231,13 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None if torch.rand(1) < self.prob: # get bbox - bbox = tv_tensors.BoundingBoxes( + bbox = tvBoundingBoxes( annotation.get_bbox(), - format=tv_tensors.BoundingBoxFormat.XYWH, + format=tvBoundingBoxFormat.XYWH, canvas_size=img_shape, dtype=torch.float32, ) - bbox = F._meta.convert_bounding_box_format( # noqa: SLF001 - bbox, - new_format=tv_tensors.BoundingBoxFormat.XYXY, - ) + bbox = convert_bounding_box_format(bbox, new_format=tvBoundingBoxFormat.XYXY) gt_prompts.append(bbox) else: # get center point @@ -231,7 +248,9 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None ) gt_prompts.append(point) - gt_labels.append(annotation.label) + gt_labels["prompts"].append(annotation.label) + gt_labels["polygons"].append(annotation.label) + gt_labels["masks"].append(annotation.label) gt_masks.append(mask) gt_polygons.append(annotation) @@ -239,21 +258,23 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None elif isinstance(annotation, (dmBbox, dmMask, dmPoints)): pass - assert len(gt_prompts) > 0, "#prompts must be greater than 0." # noqa: S101 + if not gt_prompts: + return None - labels = torch.as_tensor(gt_labels, dtype=torch.int64) - masks = tv_tensors.Mask(torch.stack(gt_masks, dim=0), dtype=torch.uint8) + labels = { + str(prompt_type): torch.as_tensor(values, dtype=torch.int64) for prompt_type, values in gt_labels.items() + } + masks = tvMask(torch.stack(gt_masks, dim=0), dtype=torch.uint8) - # set entity without masks to avoid resizing masks return ZeroShotVisualPromptingDataEntity( - image=F.to_image(img_data), + image=to_image(img_data), img_info=ImageInfo( img_idx=index, img_shape=img_shape, ori_shape=img_shape, ), masks=masks, - labels=labels, + labels=ZeroShotVisualPromptingLabel(**labels), polygons=gt_polygons, prompts=gt_prompts, ) diff --git a/src/otx/core/data/entity/visual_prompting.py b/src/otx/core/data/entity/visual_prompting.py index 2924a65ab64..8f5eed33b57 100644 --- a/src/otx/core/data/entity/visual_prompting.py +++ b/src/otx/core/data/entity/visual_prompting.py @@ -1,6 +1,6 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -# + """Module for OTX visual prompting data entities.""" from __future__ import annotations @@ -10,19 +10,15 @@ from torchvision import tv_tensors -from otx.core.data.entity.base import ( - OTXBatchDataEntity, - OTXBatchPredEntity, - OTXDataEntity, - OTXPredEntity, - Points, -) +from otx.core.data.entity.base import OTXBatchDataEntity, OTXBatchPredEntity, OTXDataEntity, OTXPredEntity, Points from otx.core.data.entity.utils import register_pytree_node from otx.core.types.task import OTXTaskType if TYPE_CHECKING: - from datumaro import Polygon + from datumaro import Polygon as dmPolygon from torch import LongTensor + from torchvision.tv_tensors import BoundingBoxes as tvBoundingBoxes + from torchvision.tv_tensors import Mask as tvMask @register_pytree_node @@ -31,10 +27,10 @@ class VisualPromptingDataEntity(OTXDataEntity): """Data entity for visual prompting task. Attributes: - masks (tv_tensors.Mask): The masks of the instances. - labels (LongTensor): The labels of the instances. - polygons (list[Polygon]): The polygons of the instances. - bboxes (tv_tensors.BoundingBoxes): The bounding boxes of the instances. + masks (tvMask): The masks of the instances. + labels (dict[str, LongTensor]): The labels of the instances. + polygons (list[dmPolygon]): The polygons of the instances. + bboxes (tvBoundingBoxes): The bounding boxes of the instances. points (Points): The points of the instances. """ @@ -43,10 +39,10 @@ def task(self) -> OTXTaskType: """OTX Task type definition.""" return OTXTaskType.VISUAL_PROMPTING - masks: tv_tensors.Mask + masks: tvMask labels: dict[str, LongTensor] - polygons: list[Polygon] - bboxes: tv_tensors.BoundingBoxes + polygons: list[dmPolygon] + bboxes: tvBoundingBoxes points: Points @@ -60,17 +56,17 @@ class VisualPromptingBatchDataEntity(OTXBatchDataEntity[VisualPromptingDataEntit """Data entity for visual prompting task. Attributes: - masks (list[tv_tensors.Mask]): List of masks. - labels (list[LongTensor]): List of labels. - polygons (list[list[Polygon]]): List of polygons. - bboxes (list[tv_tensors.BoundingBoxes]): List of bounding boxes. + masks (list[tvMask]): List of masks. + labels (list[dict[str, LongTensor]]): List of labels. + polygons (list[list[dmPolygon]]): List of polygons. + bboxes (list[tvBoundingBoxes]): List of bounding boxes. points (list[Points]): List of points. """ - masks: list[tv_tensors.Mask] + masks: list[tvMask] labels: list[dict[str, LongTensor]] - polygons: list[list[Polygon]] - bboxes: list[tv_tensors.BoundingBoxes] + polygons: list[list[dmPolygon]] + bboxes: list[tvBoundingBoxes] points: list[Points] @property @@ -131,16 +127,26 @@ class VisualPromptingBatchPredEntity(OTXBatchPredEntity, VisualPromptingBatchDat """Data entity to represent model output predictions for visual prompting task.""" +@dataclass +class ZeroShotVisualPromptingLabel: + """Label dataclass for zero-shot visual prompting data entity.""" + + prompts: LongTensor | None = None + polygons: LongTensor | None = None + masks: LongTensor | None = None + + @register_pytree_node @dataclass class ZeroShotVisualPromptingDataEntity(OTXDataEntity): """Data entity for zero-shot visual prompting task. Attributes: - masks (tv_tensors.Mask): The masks of the instances. - labels (LongTensor): The labels of the instances. - polygons (list[Polygon]): The polygons of the instances. - prompts (list[tv_tensors.TVTensor]): The prompts of the instances. + masks (tvMask): The masks of the instances. + labels (ZeroShotVisualPromptingLabel): The labels of the instances + for each prompt. + polygons (list[dmPolygon]): The polygons of the instances. + prompts (list[tvBoundingBoxes | Points]): The prompts of the instances. """ @property @@ -148,10 +154,10 @@ def task(self) -> OTXTaskType: """OTX Task type definition.""" return OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING - masks: tv_tensors.Mask - labels: list[LongTensor] - polygons: list[Polygon] - prompts: list[tv_tensors.TVTensor] + masks: tvMask + labels: ZeroShotVisualPromptingLabel + polygons: list[dmPolygon] + prompts: list[tvBoundingBoxes | Points] @dataclass @@ -159,16 +165,16 @@ class ZeroShotVisualPromptingBatchDataEntity(OTXBatchDataEntity[ZeroShotVisualPr """Data entity for zero-shot visual prompting task. Attributes: - masks (list[tv_tensors.Mask]): List of masks. - labels (list[LongTensor]): List of labels. - polygons (list[list[Polygon]]): List of polygons. - prompts (list[list[tv_tensors.TVTensor]]): List of prompts. + masks (list[tvMask]): List of masks. + labels (list[ZeroShotVisualPromptingLabel]): List of labels. + polygons (list[list[dmPolygon]]): List of polygons. + prompts (list[list[tvBoundingBoxes | Points]]): List of prompts. """ - masks: list[tv_tensors.Mask] - labels: list[LongTensor] - polygons: list[list[Polygon]] - prompts: list[list[tv_tensors.TVTensor]] + masks: list[tvMask] + labels: list[ZeroShotVisualPromptingLabel] + polygons: list[list[dmPolygon]] + prompts: list[list[tvBoundingBoxes | Points]] @property def task(self) -> OTXTaskType: @@ -211,7 +217,10 @@ def pin_memory(self) -> ZeroShotVisualPromptingBatchDataEntity: for prompts in self.prompts ], masks=[tv_tensors.wrap(mask.pin_memory(), like=mask) for mask in self.masks], - labels=[label.pin_memory() for label in self.labels], + labels=[ + ZeroShotVisualPromptingLabel(**{k: v.pin_memory() for k, v in label.__dict__.items()}) + for label in self.labels + ], ) ) diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py index 7a4fa917993..a27bd2a179c 100644 --- a/src/otx/core/model/visual_prompting.py +++ b/src/otx/core/model/visual_prompting.py @@ -1,6 +1,6 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -# + """Class definition for visual prompting models entity used in OTX.""" from __future__ import annotations @@ -14,6 +14,7 @@ import numpy as np import torch +from datumaro import Polygon as dmPolygon from model_api.models import Model from model_api.models.visual_prompting import ( Prompt, @@ -130,6 +131,11 @@ def _inference_step_for_zero_shot( if not isinstance(preds, ZeroShotVisualPromptingBatchPredEntity): raise TypeError(preds) + # filter labels using corresponding ground truth + inputs.labels = [ + label.masks if inputs.masks and label.masks is not None else label.polygons for label in inputs.labels + ] + converted_entities: dict[str, list[dict[str, Tensor]]] = _convert_pred_entity_to_compute_metric(preds, inputs) # type: ignore[assignment] for _name, _metric in metric.items(): @@ -911,9 +917,10 @@ def _customize_inputs( # type: ignore[override] images: list[np.ndarray] = [] processed_prompts: list[dict[str, Any]] = [] - for image, prompts, labels in zip( + for image, prompts, polygons, labels in zip( entity.images, entity.prompts, + entity.polygons, entity.labels, ): # preprocess image encoder inputs @@ -921,20 +928,27 @@ def _customize_inputs( # type: ignore[override] images.append(numpy_image) if self.training: - points: list[Prompt] = [] - bboxes: list[Prompt] = [] - for prompt, label in zip(prompts, labels): # type: ignore[arg-type] + _bboxes: list[Prompt] = [] + _points: list[Prompt] = [] + _polygons: list[Prompt] = [] + for prompt, label in zip(prompts, labels.prompts): # type: ignore[arg-type] if isinstance(prompt, tv_tensors.BoundingBoxes): - bboxes.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy())) + _bboxes.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy())) elif isinstance(prompt, Points): - points.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy())) - # TODO (sungchul): support polygons + _points.append(Prompt(prompt.cpu().numpy(), label.cpu().numpy())) + + if polygons and labels.polygons is not None: + for polygon, label in zip(polygons, labels.polygons): + _polygons.append(Prompt(np.array(polygon.points, dtype=np.int32), label.cpu().numpy())) + + # TODO (sungchul, sovrasov): support mask? # preprocess decoder inputs processed_prompts.append( { - "boxes": bboxes, - "points": points, + "boxes": _bboxes, + "points": _points, + "polygons": _polygons, }, ) @@ -1047,7 +1061,7 @@ def transform_fn( _labels: dict[str, list[int]] = defaultdict(list) # use only the first prompt - for prompt, label in zip(data_batch.prompts[0], data_batch.labels[0]): # type: ignore[arg-type] + for prompt, label in zip(data_batch.prompts[0], data_batch.labels[0].prompts): # type: ignore[arg-type] if isinstance(prompt, tv_tensors.BoundingBoxes): bboxes.append(prompt.cpu().numpy()) _labels["bboxes"].append(label.cpu().numpy()) diff --git a/tests/unit/algo/visual_prompting/conftest.py b/tests/unit/algo/visual_prompting/conftest.py index f9567775649..8e31db96703 100644 --- a/tests/unit/algo/visual_prompting/conftest.py +++ b/tests/unit/algo/visual_prompting/conftest.py @@ -1,122 +1,6 @@ # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations +from tests.unit.core.conftest import fxt_vpm_data_entity, fxt_zero_shot_vpm_data_entity -import pytest -import torch -from otx.core.data.entity.base import ImageInfo, Points -from otx.core.data.entity.visual_prompting import ( - VisualPromptingBatchDataEntity, - VisualPromptingBatchPredEntity, - VisualPromptingDataEntity, - ZeroShotVisualPromptingBatchDataEntity, - ZeroShotVisualPromptingBatchPredEntity, - ZeroShotVisualPromptingDataEntity, -) -from torchvision import tv_tensors - - -@pytest.fixture(scope="session") -def fxt_vpm_data_entity() -> ( - tuple[VisualPromptingDataEntity, VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity] -): - img_size = (1024, 1024) - fake_image = tv_tensors.Image(torch.rand(img_size)) - fake_image_info = ImageInfo(img_idx=0, img_shape=img_size, ori_shape=img_size) - fake_bboxes = tv_tensors.BoundingBoxes( - [[0, 0, 1, 1]], - format=tv_tensors.BoundingBoxFormat.XYXY, - canvas_size=img_size, - dtype=torch.float32, - ) - fake_points = Points([[2, 2]], canvas_size=img_size, dtype=torch.float32) - fake_masks = tv_tensors.Mask(torch.rand(img_size)) - fake_labels = {"bboxes": torch.as_tensor([1], dtype=torch.int64)} - fake_polygons = [None] - # define data entity - single_data_entity = VisualPromptingDataEntity( - image=fake_image, - img_info=fake_image_info, - masks=fake_masks, - labels=fake_labels, - polygons=fake_polygons, - bboxes=fake_bboxes, - points=fake_points, - ) - batch_data_entity = VisualPromptingBatchDataEntity( - batch_size=1, - images=[fake_image], - imgs_info=[fake_image_info], - masks=[fake_masks], - labels=[fake_labels], - polygons=[fake_polygons], - bboxes=[fake_bboxes], - points=[fake_points], - ) - batch_pred_data_entity = VisualPromptingBatchPredEntity( - batch_size=1, - images=[fake_image], - imgs_info=[fake_image_info], - masks=[fake_masks], - labels=[fake_labels], - polygons=[fake_polygons], - bboxes=[fake_bboxes], - points=[fake_points], - scores=[], - ) - - return single_data_entity, batch_data_entity, batch_pred_data_entity - - -@pytest.fixture(scope="session") -def fxt_zero_shot_vpm_data_entity() -> ( - tuple[ - ZeroShotVisualPromptingDataEntity, - ZeroShotVisualPromptingBatchDataEntity, - ZeroShotVisualPromptingBatchPredEntity, - ] -): - img_size = (1024, 1024) - fake_image = tv_tensors.Image(torch.rand(img_size)) - fake_image_info = ImageInfo(img_idx=0, img_shape=img_size, ori_shape=img_size) - fake_bboxes = tv_tensors.BoundingBoxes( - [[0, 0, 1, 1]], - format=tv_tensors.BoundingBoxFormat.XYXY, - canvas_size=img_size, - dtype=torch.float32, - ) - fake_points = Points([[2, 2]], canvas_size=img_size, dtype=torch.float32) - fake_masks = tv_tensors.Mask(torch.rand(img_size)) - fake_labels = torch.as_tensor([1], dtype=torch.int64) - fake_polygons = [None] - # define data entity - single_data_entity = ZeroShotVisualPromptingDataEntity( - image=fake_image, - img_info=fake_image_info, - masks=fake_masks, - labels=fake_labels, - polygons=fake_polygons, - prompts=[fake_bboxes, fake_points], - ) - batch_data_entity = ZeroShotVisualPromptingBatchDataEntity( - batch_size=1, - images=[fake_image], - imgs_info=[fake_image_info], - masks=[fake_masks], - labels=[fake_labels], - polygons=[fake_polygons], - prompts=[[fake_bboxes, fake_points]], - ) - batch_pred_data_entity = ZeroShotVisualPromptingBatchPredEntity( - batch_size=1, - images=[fake_image], - imgs_info=[fake_image_info], - masks=[fake_masks], - labels=[fake_labels], - polygons=[fake_polygons], - prompts=[[fake_bboxes, fake_points]], - scores=[], - ) - - return single_data_entity, batch_data_entity, batch_pred_data_entity +__all__ = ["fxt_vpm_data_entity", "fxt_zero_shot_vpm_data_entity"] diff --git a/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py b/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py index b9fe3a5e58e..90dbddd65c2 100644 --- a/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py +++ b/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py @@ -1,8 +1,9 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +from copy import deepcopy from typing import Any, Callable import pytest @@ -576,17 +577,16 @@ def test_customize_outputs(self, model, fxt_zero_shot_vpm_data_entity) -> None: assert result.labels[0] == label assert torch.all(result.prompts[0].data == outputs[0][1][label][0][:2].unsqueeze(0)) - def test_gather_prompts_with_labels(self, model) -> None: + def test_gather_prompts_with_labels(self, model, fxt_zero_shot_vpm_data_entity) -> None: """Test _gather_prompts_with_labels.""" - prompts = [[torch.tensor(0), torch.tensor(1), torch.tensor(2), torch.tensor(2), torch.tensor(4)]] - labels = [torch.tensor([0, 1, 2, 2, 4])] + entity = deepcopy(fxt_zero_shot_vpm_data_entity[1]) - results = model._gather_prompts_with_labels(prompts, labels) + results = model._gather_prompts_with_labels(entity) - assert results[0][0][0] == prompts[0][0] - assert results[0][1][0] == prompts[0][1] - assert results[0][2] == prompts[0][2:4] - assert results[0][4][0] == prompts[0][4] + assert torch.all(results[0][1][0] == entity.prompts[0][0]) + assert torch.all(results[0][1][1] == entity.masks[0]) + assert torch.all(results[0][2][0] == entity.prompts[0][1]) + assert results[0][2][1] == entity.polygons[0][0] @pytest.mark.parametrize( ("image", "expected"), diff --git a/tests/unit/core/conftest.py b/tests/unit/core/conftest.py index d571d9630bf..0dacf5fd15a 100644 --- a/tests/unit/core/conftest.py +++ b/tests/unit/core/conftest.py @@ -5,7 +5,7 @@ import numpy as np import pytest import torch -from datumaro import Label +from datumaro import Label, Polygon from datumaro.components.annotation import AnnotationType, LabelCategories from datumaro.components.dataset import Dataset, DatasetItem from datumaro.components.media import Image @@ -18,6 +18,7 @@ ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingBatchPredEntity, ZeroShotVisualPromptingDataEntity, + ZeroShotVisualPromptingLabel, ) from torchvision import tv_tensors @@ -151,8 +152,12 @@ def fxt_zero_shot_vpm_data_entity() -> ( ) fake_points = Points([[2, 2]], canvas_size=img_size, dtype=torch.float32) fake_masks = tv_tensors.Mask(torch.ones(1, *img_size)) - fake_labels = torch.as_tensor([1, 2], dtype=torch.int64) - fake_polygons = [None] + fake_labels = ZeroShotVisualPromptingLabel( + prompts=torch.as_tensor([1, 2], dtype=torch.int64), + masks=torch.as_tensor([1], dtype=torch.int64), + polygons=torch.as_tensor([2], dtype=torch.int64), + ) + fake_polygons = [Polygon(points=[1, 1, 1, 2, 2, 2, 2, 1])] fake_scores = torch.tensor([[1.0]]) # define data entity single_data_entity = ZeroShotVisualPromptingDataEntity( @@ -177,7 +182,7 @@ def fxt_zero_shot_vpm_data_entity() -> ( images=[fake_image], imgs_info=[fake_image_info], masks=[fake_masks], - labels=[fake_labels], + labels=[fake_labels.prompts], polygons=[fake_polygons], prompts=[[fake_bboxes, fake_points]], scores=[fake_scores], diff --git a/tests/unit/core/data/dataset/test_visual_prompting.py b/tests/unit/core/data/dataset/test_visual_prompting.py index 6ff53b96014..a651cf1589b 100644 --- a/tests/unit/core/data/dataset/test_visual_prompting.py +++ b/tests/unit/core/data/dataset/test_visual_prompting.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Unit tests of visual prompting datasets.""" @@ -9,7 +9,7 @@ from datumaro import Dataset as DmDataset from otx.core.data.dataset.visual_prompting import OTXVisualPromptingDataset, OTXZeroShotVisualPromptingDataset from otx.core.data.entity.base import ImageInfo, Points -from torch import Tensor +from otx.core.data.entity.visual_prompting import ZeroShotVisualPromptingLabel from torchvision.transforms.v2 import Identity, Transform from torchvision.tv_tensors import BoundingBoxes, Image, Mask @@ -103,7 +103,7 @@ def test_get_item_impl_subset( assert hasattr(entity, "masks") assert isinstance(entity.masks, Mask) assert hasattr(entity, "labels") - assert isinstance(entity.labels, Tensor) + assert isinstance(entity.labels, ZeroShotVisualPromptingLabel) assert hasattr(entity, "polygons") assert isinstance(entity.polygons, list) assert hasattr(entity, "prompts") diff --git a/tests/unit/core/model/test_visual_prompting.py b/tests/unit/core/model/test_visual_prompting.py index 93f169cfdd8..30e9c8a04bc 100644 --- a/tests/unit/core/model/test_visual_prompting.py +++ b/tests/unit/core/model/test_visual_prompting.py @@ -65,17 +65,15 @@ def test_inference_step(mocker, otx_visual_prompting_model, fxt_vpm_data_entity) def test_inference_step_for_zero_shot(mocker, otx_visual_prompting_model, fxt_zero_shot_vpm_data_entity) -> None: """Test _inference_step_for_zero_shot.""" + entity = deepcopy(fxt_zero_shot_vpm_data_entity[1]) + pred_entity = deepcopy(fxt_zero_shot_vpm_data_entity[2]) otx_visual_prompting_model.configure_metric() - mocker.patch.object(otx_visual_prompting_model, "forward", return_value=fxt_zero_shot_vpm_data_entity[2]) + mocker.patch.object(otx_visual_prompting_model, "forward", return_value=pred_entity) mocker_updates = {} for k, v in otx_visual_prompting_model.metric.items(): mocker_updates[k] = mocker.patch.object(v, "update") - _inference_step_for_zero_shot( - otx_visual_prompting_model, - otx_visual_prompting_model.metric, - fxt_zero_shot_vpm_data_entity[1], - ) + _inference_step_for_zero_shot(otx_visual_prompting_model, otx_visual_prompting_model.metric, entity) for v in mocker_updates.values(): v.assert_called_once() @@ -88,8 +86,10 @@ def test_inference_step_for_zero_shot_with_more_preds( ) -> None: """Test _inference_step_for_zero_shot with more preds.""" otx_visual_prompting_model.configure_metric() + entity = deepcopy(fxt_zero_shot_vpm_data_entity[1]) + pred_entity = deepcopy(fxt_zero_shot_vpm_data_entity[2]) preds = {} - for k, v in fxt_zero_shot_vpm_data_entity[2].__dict__.items(): + for k, v in pred_entity.__dict__.items(): if k in ["batch_size", "polygons"]: preds[k] = v else: @@ -103,11 +103,7 @@ def test_inference_step_for_zero_shot_with_more_preds( for k, v in otx_visual_prompting_model.metric.items(): mocker_updates[k] = mocker.patch.object(v, "update") - _inference_step_for_zero_shot( - otx_visual_prompting_model, - otx_visual_prompting_model.metric, - fxt_zero_shot_vpm_data_entity[1], - ) + _inference_step_for_zero_shot(otx_visual_prompting_model, otx_visual_prompting_model.metric, entity) for v in mocker_updates.values(): v.assert_called_once() @@ -120,12 +116,14 @@ def test_inference_step_for_zero_shot_with_more_target( ) -> None: """Test _inference_step_for_zero_shot with more target.""" otx_visual_prompting_model.configure_metric() - mocker.patch.object(otx_visual_prompting_model, "forward", return_value=fxt_zero_shot_vpm_data_entity[2]) + entity = deepcopy(fxt_zero_shot_vpm_data_entity[1]) + pred_entity = deepcopy(fxt_zero_shot_vpm_data_entity[2]) + mocker.patch.object(otx_visual_prompting_model, "forward", return_value=pred_entity) mocker_updates = {} for k, v in otx_visual_prompting_model.metric.items(): mocker_updates[k] = mocker.patch.object(v, "update") target = {} - for k, v in fxt_zero_shot_vpm_data_entity[1].__dict__.items(): + for k, v in entity.__dict__.items(): if k in ["batch_size"]: target[k] = v else: @@ -444,6 +442,7 @@ def test_forward( def test_learn(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_vpm_data_entity) -> None: """Test learn.""" + entity = deepcopy(fxt_zero_shot_vpm_data_entity[1]) ov_zero_shot_visual_prompting_model.model._reference_features = np.zeros((0, 1, 256), dtype=np.float32) ov_zero_shot_visual_prompting_model.model._used_indices = np.array([], dtype=np.int64) ov_zero_shot_visual_prompting_model.model.decoder.mask_threshold = 0.0 @@ -456,16 +455,17 @@ def test_learn(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_ return_value=np.random.rand(1, 256), ) reference_info, ref_masks = ov_zero_shot_visual_prompting_model.learn( - inputs=fxt_zero_shot_vpm_data_entity[1], + inputs=entity, reset_feat=False, ) - assert reference_info["reference_feats"].shape == torch.Size((2, 1, 256)) + assert reference_info["reference_feats"].shape == torch.Size((3, 1, 256)) assert 1 in reference_info["used_indices"] - assert ref_masks[0].shape == torch.Size((2, 1024, 1024)) + assert ref_masks[0].shape == torch.Size((3, 1024, 1024)) def test_infer(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_vpm_data_entity) -> None: """Test infer.""" + entity = deepcopy(fxt_zero_shot_vpm_data_entity[1]) ov_zero_shot_visual_prompting_model.model.decoder.mask_threshold = 0.0 ov_zero_shot_visual_prompting_model.model.decoder.output_blob_name = "upscaled_masks" @@ -486,7 +486,7 @@ def test_infer(self, mocker, ov_zero_shot_visual_prompting_model, fxt_zero_shot_ used_indices = np.array([1]) results = ov_zero_shot_visual_prompting_model.infer( - inputs=fxt_zero_shot_vpm_data_entity[1], + inputs=entity, reference_feats=reference_feats, used_indices=used_indices, ) @@ -500,11 +500,12 @@ def test_customize_outputs_training( ov_zero_shot_visual_prompting_model, fxt_zero_shot_vpm_data_entity, ) -> None: + entity = deepcopy(fxt_zero_shot_vpm_data_entity[1]) ov_zero_shot_visual_prompting_model.training = True outputs = ({"foo": np.array(1), "bar": np.array(2)}, [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]) - result = ov_zero_shot_visual_prompting_model._customize_outputs(outputs, fxt_zero_shot_vpm_data_entity[1]) + result = ov_zero_shot_visual_prompting_model._customize_outputs(outputs, entity) assert result == outputs