Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable to use polygon and bitmap mask as reference mask for zero-shot learning #3769

Merged
merged 14 commits into from
Jul 30, 2024
2 changes: 1 addition & 1 deletion src/otx/algo/visual_prompting/segment_anything.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2023 Intel Corporation
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Segment Anything model for the OTX visual prompting."""
Expand Down
58 changes: 38 additions & 20 deletions src/otx/algo/visual_prompting/zero_shot_segment_anything.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2023 Intel Corporation
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Segment Anything model for the OTX zero-shot visual prompting."""
Expand All @@ -19,7 +19,7 @@
from torch import Tensor, nn
from torch.nn import functional as F # noqa: N812
from torchvision import tv_tensors
from torchvision.tv_tensors import BoundingBoxes, Image, Mask, TVTensor
from torchvision.tv_tensors import BoundingBoxes, Image, Mask

from otx.algo.visual_prompting.segment_anything import DEFAULT_CONFIG_SEGMENT_ANYTHING, SegmentAnything
from otx.core.data.entity.base import OTXBatchLossEntity, Points
Expand All @@ -32,6 +32,7 @@
from otx.core.model.visual_prompting import OTXZeroShotVisualPromptingModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import LabelInfoTypes, NullLabelInfo
from otx.core.utils.mask_util import polygon_to_bitmap

if TYPE_CHECKING:
import numpy as np
Expand Down Expand Up @@ -230,7 +231,7 @@ def expand_reference_info(self, reference_feats: Tensor, new_largest_label: int)
def learn(
self,
images: list[Image],
processed_prompts: list[dict[int, list[TVTensor]]],
processed_prompts: list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]],
reference_feats: Tensor,
used_indices: Tensor,
ori_shapes: list[Tensor],
Expand All @@ -244,15 +245,15 @@ def learn(

Args:
images (list[Image]): List of given images for reference features.
processed_prompts (dict[int, list[TVTensor]]): The class-wise prompts
processed_prompts (dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]): The class-wise prompts
processed at OTXZeroShotSegmentAnything._gather_prompts_with_labels.
reference_feats (Tensor): Reference features for target prediction.
used_indices (Tensor): To check which indices of reference features are validate.
ori_shapes (List[Tensor]): List of original shapes per image.
is_cascade (bool): Whether use cascade inference. Defaults to False.
"""
# initialize tensors to contain reference features and prompts
largest_label = max(sum([[int(p) for p in prompt] for prompt in processed_prompts], []))
largest_label = max([max(prompt.keys()) for prompt in processed_prompts])
reference_feats = self.expand_reference_info(reference_feats, largest_label)
new_used_indices: list[Tensor] = []
# TODO (sungchul): consider how to handle multiple reference features, currently replace it
Expand All @@ -270,20 +271,16 @@ def learn(
for input_prompt in input_prompts:
if isinstance(input_prompt, Mask):
# directly use annotation information as a mask
ref_mask[input_prompt == 1] += 1 # TODO (sungchul): check if the mask is bool or int
ref_mask[input_prompt] += 1
elif isinstance(input_prompt, dmPolygon):
ref_mask[torch.as_tensor(polygon_to_bitmap([input_prompt], *ori_shape)[0])] += 1
else:
if isinstance(input_prompt, BoundingBoxes):
point_coords = input_prompt.reshape(-1, 2, 2)
point_labels = torch.tensor([[2, 3]], device=point_coords.device)
elif isinstance(input_prompt, Points):
point_coords = input_prompt.reshape(-1, 1, 2)
point_labels = torch.tensor([[1]], device=point_coords.device)
elif isinstance(
input_prompt,
dmPolygon,
): # TODO (sungchul): add other polygon types
# TODO (sungchul): convert polygon to mask
continue
else:
log.info(f"Current input prompt ({input_prompt.__class__.__name__}) is not supported.")
continue
Expand Down Expand Up @@ -745,7 +742,14 @@ def _customize_inputs( # type: ignore[override]
if self.training:
# learn
forward_inputs.update(
{"processed_prompts": self._gather_prompts_with_labels(inputs.prompts, inputs.labels)},
{
"processed_prompts": self._gather_prompts_with_labels(
inputs.labels,
inputs.prompts,
inputs.polygons,
inputs.masks,
),
},
eunwoosh marked this conversation as resolved.
Show resolved Hide resolved
)

return forward_inputs
Expand Down Expand Up @@ -810,17 +814,29 @@ def _customize_outputs( # type: ignore[override]

def _gather_prompts_with_labels(
self,
prompts: list[list[TVTensor]],
labels: list[Tensor],
) -> list[dict[int, list[TVTensor]]]:
labels: list[dict[Literal["prompts", "polygons", "masks"], Tensor]],
prompts: list[list[BoundingBoxes | Points]] | None = None,
polygons: list[list[dmPolygon]] | None = None,
masks: list[Mask] | None = None,
) -> list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]]:
"""Gather prompts according to labels."""
total_processed_prompts: list[dict[int, list[TVTensor]]] = []
for prompt, label in zip(prompts, labels):
total_processed_prompts: list[dict[int, list[BoundingBoxes | Points | dmPolygon | Mask]]] = []
for batch, batch_labels in enumerate(labels):
processed_prompts = defaultdict(list)
for _prompt, _label in zip(prompt, label): # type: ignore[arg-type]
processed_prompts[int(_label)].append(_prompt)
for prompt_type, prompt_labels in batch_labels.items():
_prompts = locals()[prompt_type]
if _prompts is None:
continue
for idx, _label in enumerate(prompt_labels):
if prompt_type in ("prompts", "polygons"):
processed_prompts[int(_label)].append(_prompts[batch][idx])
else:
# for mask
processed_prompts[int(_label)].append(Mask(_prompts[batch][idx]))

sorted_processed_prompts = dict(sorted(processed_prompts.items(), key=lambda x: x))
total_processed_prompts.append(sorted_processed_prompts)

return total_processed_prompts

def apply_image(self, image: Image | np.ndarray, target_length: int = 1024) -> Image:
Expand Down Expand Up @@ -893,6 +909,8 @@ def transforms(self, entity: ZeroShotVisualPromptingBatchDataEntity) -> ZeroShot
self.apply_prompts(prompt, info.ori_shape, self.model.image_size)
for prompt, info in zip(entity.prompts, entity.imgs_info)
],
masks=entity.masks,
polygons=entity.polygons,
)

def initialize_reference_info(self) -> None:
Expand Down
69 changes: 43 additions & 26 deletions src/otx/core/data/dataset/visual_prompting.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
# Copyright (C) 2023 Intel Corporation
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

"""Module for OTXVisualPromptingDataset."""

from __future__ import annotations

from collections import defaultdict
from functools import partial
from typing import Callable
from typing import Callable, Literal

import torch
import torchvision.transforms.v2.functional as F # noqa: N812
from datumaro import Bbox as dmBbox
from datumaro import Dataset as dmDataset
from datumaro import Image as dmImage
from datumaro import Mask as dmMask
from datumaro import Points as dmPoints
from datumaro import Polygon as dmPolygon
from torchvision import tv_tensors
from torchvision.transforms.v2.functional import convert_bounding_box_format, to_image
from torchvision.tv_tensors import BoundingBoxes as tvBoundingBoxes
from torchvision.tv_tensors import Mask as tvMask

from otx.core.data.entity.base import ImageInfo, Points
from otx.core.data.entity.visual_prompting import (
Expand All @@ -38,6 +40,14 @@ class OTXVisualPromptingDataset(OTXDataset[VisualPromptingDataEntity]):
Args:
dm_subset (dmDataset): The subset of the dataset.
transforms (Transforms): Data transformations to be applied.
use_bbox (bool): Whether to use bounding box prompt.
If both use_bbox and use_point are False, use_bbox is set to True as default.
If both are True, divide the probability into both.
Defaults to True.
use_point (bool): Whether to use point prompt.
If both use_bbox and use_point are False, use_bbox is set to True as default.
If both are True, divide the probability into both.
Defaults to False.
**kwargs: Additional keyword arguments passed to the base class.
"""

Expand Down Expand Up @@ -76,24 +86,21 @@ def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None:

for annotation in item.annotations:
if isinstance(annotation, dmPolygon):
mask = tv_tensors.Mask(polygon_to_bitmap([annotation], *img_shape)[0])
mask = tvMask(polygon_to_bitmap([annotation], *img_shape)[0])
mask_points = torch.nonzero(mask)
if len(mask_points[0]) == 0:
# skip very small region
continue

if torch.rand(1) < self.prob:
# get bbox
bbox = tv_tensors.BoundingBoxes(
bbox = tvBoundingBoxes(
annotation.get_bbox(),
format=tv_tensors.BoundingBoxFormat.XYWH,
format="xywh",
canvas_size=img_shape,
sovrasov marked this conversation as resolved.
Show resolved Hide resolved
dtype=torch.float32,
)
bbox = F._meta.convert_bounding_box_format( # noqa: SLF001
bbox,
new_format=tv_tensors.BoundingBoxFormat.XYXY,
)
bbox = convert_bounding_box_format(bbox, new_format="xyxy")
gt_bboxes.append(bbox)
gt_labels["bboxes"].append(annotation.label)
gt_masks["bboxes"].append(mask)
Expand Down Expand Up @@ -127,7 +134,7 @@ def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None:
bboxes = tv_tensors.wrap(torch.cat(gt_bboxes, dim=0), like=gt_bboxes[0]) if len(gt_bboxes) > 0 else None
points = tv_tensors.wrap(torch.stack(gt_points, dim=0), like=gt_points[0]) if len(gt_points) > 0 else None
labels = {prompt_type: torch.as_tensor(values, dtype=torch.int64) for prompt_type, values in gt_labels.items()}
masks = tv_tensors.Mask(
masks = tvMask(
torch.stack(gt_masks.get("bboxes", []) + gt_masks.get("points", []), dim=0),
dtype=torch.uint8,
)
Expand Down Expand Up @@ -168,6 +175,14 @@ class OTXZeroShotVisualPromptingDataset(OTXDataset[ZeroShotVisualPromptingDataEn
Args:
dm_subset (dmDataset): The subset of the dataset.
transforms (Transforms): Data transformations to be applied.
use_bbox (bool): Whether to use bounding box prompt.
If both use_bbox and use_point are False, use_bbox is set to True as default.
If both are True, divide the probability into both.
Defaults to True.
use_point (bool): Whether to use point prompt.
If both use_bbox and use_point are False, use_bbox is set to True as default.
If both are True, divide the probability into both.
Defaults to False.
**kwargs: Additional keyword arguments passed to the base class.
"""

Expand Down Expand Up @@ -199,28 +214,28 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None
img = item.media_as(dmImage)
img_data, img_shape = self._get_img_data_and_shape(img)

gt_prompts, gt_masks, gt_polygons, gt_labels = [], [], [], []
gt_prompts: list[tvBoundingBoxes | Points] = []
gt_masks: list[tvMask] = []
gt_polygons: list[dmPolygon] = []
gt_labels: dict[Literal["prompts", "polygons", "masks"], list[int]] = defaultdict(list)
for annotation in item.annotations:
if isinstance(annotation, dmPolygon):
# generate prompts from polygon
mask = tv_tensors.Mask(polygon_to_bitmap([annotation], *img_shape)[0])
mask = tvMask(polygon_to_bitmap([annotation], *img_shape)[0])
mask_points = torch.nonzero(mask)
if len(mask_points[0]) == 0:
# skip very small region
continue

if torch.rand(1) < self.prob:
# get bbox
bbox = tv_tensors.BoundingBoxes(
bbox = tvBoundingBoxes(
annotation.get_bbox(),
format=tv_tensors.BoundingBoxFormat.XYWH,
format="xywh",
canvas_size=img_shape,
dtype=torch.float32,
)
bbox = F._meta.convert_bounding_box_format( # noqa: SLF001
bbox,
new_format=tv_tensors.BoundingBoxFormat.XYXY,
)
bbox = convert_bounding_box_format(bbox, new_format="xyxy")
gt_prompts.append(bbox)
else:
# get center point
Expand All @@ -231,22 +246,24 @@ def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None
)
gt_prompts.append(point)

gt_labels.append(annotation.label)
gt_labels["prompts"].append(annotation.label)
gt_labels["polygons"].append(annotation.label)
gt_labels["masks"].append(annotation.label)
gt_masks.append(mask)
eunwoosh marked this conversation as resolved.
Show resolved Hide resolved
gt_polygons.append(annotation)

# TODO(sungchul): for mask, bounding box, and point annotation
elif isinstance(annotation, (dmBbox, dmMask, dmPoints)):
pass

assert len(gt_prompts) > 0, "#prompts must be greater than 0." # noqa: S101
if len(gt_prompts) == 0:
return None
eunwoosh marked this conversation as resolved.
Show resolved Hide resolved

labels = torch.as_tensor(gt_labels, dtype=torch.int64)
masks = tv_tensors.Mask(torch.stack(gt_masks, dim=0), dtype=torch.uint8)
labels = {prompt_type: torch.as_tensor(values, dtype=torch.int64) for prompt_type, values in gt_labels.items()}
masks = tvMask(torch.stack(gt_masks, dim=0), dtype=torch.uint8)

# set entity without masks to avoid resizing masks
return ZeroShotVisualPromptingDataEntity(
image=F.to_image(img_data),
image=to_image(img_data),
img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
Expand Down
Loading
Loading