Tiling Semantic Seg #3954

Merged
140 changes: 129 additions & 11 deletions src/otx/core/data/dataset/tile.py
@@ -14,9 +14,9 @@
import numpy as np
import shapely.geometry as sg
import torch
from datumaro import Bbox, DatasetItem, Image, Polygon
from datumaro import Dataset as DmDataset
from datumaro.components.annotation import AnnotationType
from datumaro import DatasetItem, Image
from datumaro.components.annotation import AnnotationType, Bbox, ExtractedMask, Polygon
from datumaro.plugins.tiling import Tile
from datumaro.plugins.tiling.tile import _apply_offset
from datumaro.plugins.tiling.util import (
@@ -27,14 +27,18 @@
)
from torchvision import tv_tensors

from otx.core.data.dataset.segmentation import _extract_class_mask
from otx.core.data.entity.base import ImageInfo
from otx.core.data.entity.detection import DetDataEntity
from otx.core.data.entity.instance_segmentation import InstanceSegDataEntity
from otx.core.data.entity.segmentation import SegDataEntity
from otx.core.data.entity.tile import (
TileBatchDetDataEntity,
TileBatchInstSegDataEntity,
TileBatchSegDataEntity,
TileDetDataEntity,
TileInstSegDataEntity,
TileSegDataEntity,
)
from otx.core.types.task import OTXTaskType
from otx.core.utils.mask_util import polygon_to_bitmap
@@ -47,6 +51,7 @@
from otx.core.config.data import TileConfig
from otx.core.data.dataset.detection import OTXDetectionDataset
from otx.core.data.dataset.instance_segmentation import OTXInstanceSegDataset
from otx.core.data.dataset.segmentation import OTXSegmentationDataset
from otx.core.data.entity.base import OTXDataEntity

# ruff: noqa: SLF001
@@ -87,6 +92,7 @@ def __init__(
)
self._tile_size = tile_size
self._tile_ann_func_map[AnnotationType.polygon] = OTXTileTransform._tile_polygon
self._tile_ann_func_map[AnnotationType.mask] = OTXTileTransform._tile_masks
self.with_full_img = with_full_img

@staticmethod
Expand Down Expand Up @@ -127,6 +133,30 @@ def _tile_polygon(
attributes=deepcopy(ann.attributes),
)

@staticmethod
def _tile_masks(
ann: ExtractedMask,
roi_int: BboxIntCoords,
*args, # noqa: ARG004
**kwargs, # noqa: ARG004
) -> ExtractedMask:
"""Extracts a tile mask from the given annotation.

Note: the original Datumaro _tile_masks does not work with ExtractedMask.

Args:
ann (ExtractedMask): datumaro ExtractedMask annotation.
roi_int (BboxIntCoords): ROI coordinates.

Returns:
ExtractedMask: ExtractedMask annotation.
"""
x, y, w, h = roi_int
return ann.wrap(
index_mask=ann.index_mask()[y : y + h, x : x + w],
attributes=deepcopy(ann.attributes),
)
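# Illustrative sketch (not part of the PR): the crop above is plain 2-D slicing of
# the label-index mask returned by ann.index_mask(). With a hypothetical 4x6 index
# mask and an ROI of (x=2, y=1, w=3, h=2), the tile keeps only the labels inside
# that window:
import numpy as np

index_mask = np.array(
    [
        [0, 0, 1, 1, 2, 2],
        [0, 0, 1, 1, 2, 2],
        [3, 3, 1, 1, 0, 0],
        [3, 3, 0, 0, 0, 0],
    ]
)
x, y, w, h = 2, 1, 3, 2  # roi_int unpacked as (x, y, w, h)
tile_mask = index_mask[y : y + h, x : x + w]
# tile_mask:
# [[1 1 2]
#  [1 1 0]]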

def _extract_rois(self, image: Image) -> list[BboxIntCoords]:
"""Extracts Tile ROIs from the given image.

@@ -195,6 +225,9 @@ def create(
return OTXTileDetTestDataset(dataset, tile_config)
if task in [OTXTaskType.ROTATED_DETECTION, OTXTaskType.INSTANCE_SEGMENTATION]:
return OTXTileInstSegTestDataset(dataset, tile_config)
if task == OTXTaskType.SEMANTIC_SEGMENTATION:
return OTXTileSemanticSegTestDataset(dataset, tile_config)

msg = f"Unsupported task type: {task} for tiling"
raise NotImplementedError(msg)

@@ -218,6 +251,15 @@ def __init__(self, dataset: OTXDataset, tile_config: TileConfig) -> None:
self.tile_config = tile_config
self._dataset = dataset

# LabelInfo differs from SegLabelInfo, so we need to update it for semantic segmentation.
if self.label_info != dataset.label_info:
msg = (
"Replacing the label info with the dataset's label info "
"as there is a mismatch between the dataset and the tile dataset."
)
log.warning(msg)
self.label_info = dataset.label_info

def __len__(self) -> int:
return len(self._dataset)

@@ -230,17 +272,23 @@ def _get_item_impl(self, index: int) -> OTXDataEntity | None:
"""Get item implementation from the original dataset."""
return self._dataset._get_item_impl(index)

def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem) -> OTXDataEntity:
def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_idx: int) -> OTXDataEntity:
"""Convert a tile dataset item to OTXDataEntity."""
msg = "Method _convert_entity is not implemented."
raise NotImplementedError(msg)

def get_tiles(self, image: np.ndarray, item: DatasetItem) -> tuple[list[OTXDataEntity], list[dict]]:
def get_tiles(
self,
image: np.ndarray,
item: DatasetItem,
parent_idx: int,
) -> tuple[list[OTXDataEntity], list[dict]]:
"""Retrieves tiles from the given image and dataset item.

Args:
image (np.ndarray): The input image.
item (DatasetItem): The dataset item.
parent_idx (int): Index of the original dataset item, kept so that tile predictions can be merged back to the source image.

Returns:
A tuple containing two lists:
@@ -263,12 +311,13 @@ def get_tiles(self, image: np.ndarray, item: DatasetItem) -> tuple[list[OTXDataEntity], list[dict]]:
tile_entities: list[OTXDataEntity] = []
tile_attrs: list[dict] = []
for tile in tile_ds:
tile_entity = self._convert_entity(image, tile)
tile_entity = self._convert_entity(image, tile, parent_idx)
# apply the same transforms as the original dataset
transformed_tile = self._apply_transforms(tile_entity)
if transformed_tile is None:
msg = "Transformed tile is None"
raise RuntimeError(msg)
tile.attributes.update({"tile_size": self.tile_config.tile_size})
tile_entities.append(transformed_tile)
tile_attrs.append(tile.attributes)
return tile_entities, tile_attrs
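# Minimal sketch of the tiling geometry (assumptions: a plain grid with a fixed
# overlap ratio; the real ROI logic lives in OTXTileTransform._extract_rois, whose
# body is folded out of this diff and additionally supports with_full_img). It
# illustrates the kind of (x, y, w, h) ROIs that each tile's "roi" attribute carries.
def grid_rois(img_h: int, img_w: int, tile_size: int, overlap: float) -> list[tuple[int, int, int, int]]:
    stride = max(1, int(tile_size * (1.0 - overlap)))
    rois = []
    for y in range(0, img_h, stride):
        for x in range(0, img_w, stride):
            rois.append((x, y, min(tile_size, img_w - x), min(tile_size, img_h - y)))
    return rois

# e.g. an 800x1000 (HxW) image, 512-pixel tiles, 25% overlap -> 9 overlapping ROIs
print(len(grid_rois(800, 1000, 512, 0.25)))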
@@ -346,7 +395,7 @@ def _get_item_impl(self, index: int) -> TileDetDataEntity: # type: ignore[override]
)
labels = torch.as_tensor([ann.label for ann in bbox_anns])

tile_entities, tile_attrs = self.get_tiles(img_data, item)
tile_entities, tile_attrs = self.get_tiles(img_data, item, index)

return TileDetDataEntity(
num_tiles=len(tile_entities),
@@ -365,13 +414,13 @@ def _get_item_impl(self, index: int) -> TileDetDataEntity: # type: ignore[override]
ori_labels=labels,
)

def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem) -> DetDataEntity:
def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_idx: int) -> DetDataEntity:
"""Convert a tile datumaro dataset item to DetDataEntity."""
x1, y1, w, h = dataset_item.attributes["roi"]
tile_img = image[y1 : y1 + h, x1 : x1 + w]
tile_shape = tile_img.shape[:2]
img_info = ImageInfo(
img_idx=dataset_item.attributes["id"],
img_idx=parent_idx,
img_shape=tile_shape,
ori_shape=tile_shape,
)
@@ -448,7 +497,7 @@ def _get_item_impl(self, index: int) -> TileInstSegDataEntity: # type: ignore[override]
masks = np.stack(gt_masks, axis=0) if gt_masks else np.zeros((0, *img_shape), dtype=bool)
labels = np.array(gt_labels, dtype=np.int64)

tile_entities, tile_attrs = self.get_tiles(img_data, item)
tile_entities, tile_attrs = self.get_tiles(img_data, item, index)

return TileInstSegDataEntity(
num_tiles=len(tile_entities),
@@ -469,13 +518,13 @@ def _get_item_impl(self, index: int) -> TileInstSegDataEntity: # type: ignore[override]
ori_polygons=gt_polygons,
)

def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem) -> InstanceSegDataEntity:
def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_idx: int) -> InstanceSegDataEntity:
"""Convert a tile dataset item to InstanceSegDataEntity."""
x1, y1, w, h = dataset_item.attributes["roi"]
tile_img = image[y1 : y1 + h, x1 : x1 + w]
tile_shape = tile_img.shape[:2]
img_info = ImageInfo(
img_idx=dataset_item.attributes["id"],
img_idx=parent_idx,
img_shape=tile_shape,
ori_shape=tile_shape,
)
@@ -492,3 +541,72 @@ def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem) -> InstanceSegDataEntity:
masks=tv_tensors.Mask(np.zeros((0, *tile_shape), dtype=bool)),
polygons=[],
)


class OTXTileSemanticSegTestDataset(OTXTileDataset):
"""OTX tile semantic-seg test dataset.

OTXTileSemanticSegTestDataset wraps a list of tiles (SegDataEntity) into a single TileSegDataEntity
for testing/predicting.

Args:
dataset (OTXSegmentationDataset): OTX semantic-seg dataset.
tile_config (TileConfig): Tile configuration.
"""

def __init__(self, dataset: OTXSegmentationDataset, tile_config: TileConfig) -> None:
super().__init__(dataset, tile_config)
self.ignore_index = self._dataset.ignore_index

@property
def collate_fn(self) -> Callable:
"""Collate function for tile detection test dataset."""
return TileBatchSegDataEntity.collate_fn

def _get_item_impl(self, index: int) -> TileSegDataEntity: # type: ignore[override]
"""Get item implementation.

Transform a single dataset item to multiple tiles using Datumaro tiling plugin, and
wrap tiles into a single TileSegDataEntity.

Args:
index (int): Index of the dataset item.

Returns:
TileSegDataEntity: tile semantic-seg data entity that wraps a list of semantic-seg data entities.
"""
item = self.dm_subset[index]
img = item.media_as(Image)
img_data, img_shape = self._get_img_data_and_shape(img)

extracted_mask = _extract_class_mask(item=item, img_shape=img_shape, ignore_index=self.ignore_index)
masks = tv_tensors.Mask(extracted_mask[None])
tile_entities, tile_attrs = self.get_tiles(img_data, item, index)

return TileSegDataEntity(
num_tiles=len(tile_entities),
entity_list=tile_entities,
tile_attr_list=tile_attrs,
ori_img_info=ImageInfo(
img_idx=index,
img_shape=img_shape,
ori_shape=img_shape,
),
ori_masks=masks,
)

def _convert_entity(self, image: np.ndarray, dataset_item: DatasetItem, parent_idx: int) -> SegDataEntity:
"""Convert a tile datumaro dataset item to SegDataEntity."""
x1, y1, w, h = dataset_item.attributes["roi"]
tile_img = image[y1 : y1 + h, x1 : x1 + w]
tile_shape = tile_img.shape[:2]
img_info = ImageInfo(
img_idx=parent_idx,
img_shape=tile_shape,
ori_shape=tile_shape,
)
return SegDataEntity(
image=tile_img,
img_info=img_info,
masks=tv_tensors.Mask(np.zeros((0, *tile_shape), dtype=bool)),
)
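The "roi" attribute each tile carries (together with the parent_idx threaded through _convert_entity) exists so that per-tile predictions can later be merged back onto the original image. A minimal numpy sketch of that idea follows, with dummy values and last-write-wins handling of overlaps; it is not OTX's actual tile merger, which lives outside this diff.

import numpy as np

full_shape = (4, 6)  # (H, W) of the original image
tile_preds = {
    # (x, y, w, h) roi -> predicted HxW class mask for that tile (dummy values)
    (0, 0, 4, 3): np.ones((3, 4), dtype=np.uint8),
    (3, 0, 3, 3): np.full((3, 3), 2, dtype=np.uint8),
}
merged = np.zeros(full_shape, dtype=np.uint8)
for (x, y, w, h), pred in tile_preds.items():
    merged[y : y + h, x : x + w] = pred
print(merged)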
82 changes: 81 additions & 1 deletion src/otx/core/data/entity/tile.py
@@ -8,17 +8,20 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Sequence

import torch
from torchvision import tv_tensors

from otx.core.data.entity.utils import stack_batch
from otx.core.types.task import OTXTaskType

from .base import ImageInfo, T_OTXBatchDataEntity, T_OTXDataEntity
from .detection import DetBatchDataEntity, DetDataEntity
from .instance_segmentation import InstanceSegBatchDataEntity, InstanceSegDataEntity
from .segmentation import SegBatchDataEntity, SegDataEntity

if TYPE_CHECKING:
from datumaro import Polygon
from torch import LongTensor
from torchvision import tv_tensors


@dataclass
@@ -252,3 +255,80 @@ def collate_fn(cls, batch_entities: list[TileInstSegDataEntity]) -> TileBatchInstSegDataEntity:
masks=[tile_entity.ori_masks for tile_entity in batch_entities],
polygons=[tile_entity.ori_polygons for tile_entity in batch_entities],
)


@dataclass
class TileSegDataEntity(TileDataEntity):
"""Data entity for segmentation tile task.

Attributes:
ori_masks (tv_tensors.Mask): The masks of the original image.
"""

ori_masks: tv_tensors.Mask

@property
def task(self) -> OTXTaskType:
"""OTX Task type definition."""
return OTXTaskType.SEMANTIC_SEGMENTATION


@dataclass
class TileBatchSegDataEntity(OTXTileBatchDataEntity):
"""Batch data entity for semantic segmentation tile task.

Attributes:
masks (list[tv_tensors.Mask]): The masks of the original images, one per image in the batch.
"""

masks: list[tv_tensors.Mask]

def unbind(self) -> list[tuple[list[dict[str, int | str]], SegBatchDataEntity]]:
"""Unbind batch data entity for semantic segmentation task."""
tiles = [tile for tiles in self.batch_tiles for tile in tiles]
tile_infos = [tile_info for tile_infos in self.batch_tile_img_infos for tile_info in tile_infos]
tile_attr_list = [tile_attr for tile_attrs in self.batch_tile_attr_list for tile_attr in tile_attrs]

batch_tile_attr_list = [
tile_attr_list[i : i + self.batch_size] for i in range(0, len(tile_attr_list), self.batch_size)
]
batch_data_entities = [
SegBatchDataEntity(
batch_size=self.batch_size,
images=tv_tensors.wrap(torch.stack(tiles[i : i + self.batch_size]), like=tiles[0]),
imgs_info=tile_infos[i : i + self.batch_size],
masks=[[] for _ in range(self.batch_size)],
)
for i in range(0, len(tiles), self.batch_size)
]
return list(zip(batch_tile_attr_list, batch_data_entities))
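# Worked sketch of the re-chunking above (illustrative numbers, plain strings in
# place of image tensors): two images contributing 3 and 5 tiles with batch_size=2
# are flattened to 8 tiles and regrouped into 4 chunks of (at most) batch_size tiles.
batch_size = 2
batch_tiles = [
    ["img0_t0", "img0_t1", "img0_t2"],
    ["img1_t0", "img1_t1", "img1_t2", "img1_t3", "img1_t4"],
]
tiles = [tile for tiles in batch_tiles for tile in tiles]
chunks = [tiles[i : i + batch_size] for i in range(0, len(tiles), batch_size)]
# -> [['img0_t0', 'img0_t1'], ['img0_t2', 'img1_t0'],
#     ['img1_t1', 'img1_t2'], ['img1_t3', 'img1_t4']]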

@classmethod
def collate_fn(cls, batch_entities: list[TileSegDataEntity]) -> TileBatchSegDataEntity:
"""Collate function to collect TileSegDataEntity into TileBatchSegDataEntity in data loader."""
if (batch_size := len(batch_entities)) == 0:
msg = "collate_fn() input should have > 0 entities"
raise RuntimeError(msg)

task = batch_entities[0].task

for tile_entity in batch_entities:
for entity in tile_entity.entity_list:
if entity.task != task:
msg = "collate_fn() input should include a single OTX task"
raise RuntimeError(msg)

if not isinstance(entity, SegDataEntity):
msg = "All entities should be SegDataEntity before collate_fn()"
raise TypeError(msg)

return TileBatchSegDataEntity(
batch_size=batch_size,
batch_tiles=[[entity.image for entity in tile_entity.entity_list] for tile_entity in batch_entities],
batch_tile_img_infos=[
[entity.img_info for entity in tile_entity.entity_list] for tile_entity in batch_entities
],
batch_tile_attr_list=[tile_entity.tile_attr_list for tile_entity in batch_entities],
imgs_info=[tile_entity.ori_img_info for tile_entity in batch_entities],
masks=[tile_entity.ori_masks for tile_entity in batch_entities],
)
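Putting the two entity classes together, here is a hedged, self-contained sketch that hand-builds two TileSegDataEntity items from dummy tensors, collates them, and unbinds the batch back into per-tile SegBatchDataEntity chunks. The helper name dummy_tile_seg_entity and all tensor values are invented for illustration; only the constructor keywords and the collate/unbind calls mirror the code above, and the sketch assumes the dataclasses accept these fields exactly as constructed in the diff.

import torch
from torchvision import tv_tensors

from otx.core.data.entity.base import ImageInfo
from otx.core.data.entity.segmentation import SegDataEntity
from otx.core.data.entity.tile import TileBatchSegDataEntity, TileSegDataEntity


def dummy_tile_seg_entity(img_idx: int, num_tiles: int) -> TileSegDataEntity:
    tile_shape = (256, 256)
    tiles = [
        SegDataEntity(
            image=tv_tensors.Image(torch.zeros(3, *tile_shape)),
            img_info=ImageInfo(img_idx=img_idx, img_shape=tile_shape, ori_shape=tile_shape),
            masks=tv_tensors.Mask(torch.zeros(0, *tile_shape, dtype=torch.bool)),
        )
        for _ in range(num_tiles)
    ]
    attrs = [{"roi": (0, 0, 256, 256), "tile_size": (256, 256)} for _ in range(num_tiles)]
    return TileSegDataEntity(
        num_tiles=num_tiles,
        entity_list=tiles,
        tile_attr_list=attrs,
        ori_img_info=ImageInfo(img_idx=img_idx, img_shape=(512, 512), ori_shape=(512, 512)),
        ori_masks=tv_tensors.Mask(torch.zeros(1, 512, 512, dtype=torch.uint8)),
    )


batch = TileBatchSegDataEntity.collate_fn([dummy_tile_seg_entity(0, 3), dummy_tile_seg_entity(1, 5)])
for tile_attrs, seg_batch in batch.unbind():
    # 4 chunks of 2 tiles each; tile images stacked to shape (2, 3, 256, 256)
    print(len(tile_attrs), seg_batch.images.shape)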
3 changes: 2 additions & 1 deletion src/otx/core/data/module.py
@@ -138,6 +138,7 @@ def __init__( # noqa: PLR0913
if adaptive_input_size is not None:
input_size = adapt_input_size_to_dataset(
dataset,
self.task,
input_size,
adaptive_input_size == "downscale",
input_size_multiplier,
@@ -149,7 +150,7 @@
self.input_size = input_size

if self.tile_config.enable_tiler and self.tile_config.enable_adaptive_tiling:
adapt_tile_config(self.tile_config, dataset=dataset)
adapt_tile_config(self.tile_config, dataset=dataset, task=self.task)

config_mapping = {
self.train_subset.subset_name: self.train_subset,