diff --git a/src/otx/core/data/dataset/tile.py b/src/otx/core/data/dataset/tile.py index 8b799a44768..a13717643c1 100644 --- a/src/otx/core/data/dataset/tile.py +++ b/src/otx/core/data/dataset/tile.py @@ -6,14 +6,18 @@ from __future__ import annotations import logging as log +from copy import deepcopy from itertools import product from typing import TYPE_CHECKING, Callable import numpy as np +import shapely.geometry as sg import torch from datumaro import Bbox, DatasetItem, DatasetSubset, Image, Polygon from datumaro import Dataset as DmDataset +from datumaro.components.annotation import AnnotationType from datumaro.plugins.tiling import Tile +from datumaro.plugins.tiling.tile import _apply_offset from datumaro.plugins.tiling.util import ( clip_x1y1x2y2, cxcywh_to_x1y1x2y2, @@ -76,10 +80,45 @@ def __init__( threshold_drop_ann=threshold_drop_ann, ) self._tile_size = tile_size - # TODO (Eugene): Bug found in original Datumaro tile polygon function. - # https://github.com/eugene123tw/training_extensions/tree/eugene/fix-tile-polygon-func - # It lacks polygon validation, potentially leading to GeometryCollection or MultiPolygon results, - # which the current function doesn't handle. + self._tile_ann_func_map[AnnotationType.polygon] = OTXTileTransform._tile_polygon + + @staticmethod + def _tile_polygon( + ann: Polygon, + roi_box: sg.Polygon, + threshold_drop_ann: float = 0.8, + *args, # noqa: ARG004 + **kwargs, # noqa: ARG004 + ) -> Polygon | None: + polygon = sg.Polygon(ann.get_points()) + + # NOTE: polygon may be invalid, e.g. self-intersecting + if not roi_box.intersects(polygon) or not polygon.is_valid: + return None + + # NOTE: intersection may return a GeometryCollection or MultiPolygon + inter = polygon.intersection(roi_box) + if isinstance(inter, (sg.GeometryCollection, sg.MultiPolygon)): + shapes = [(geom, geom.area) for geom in list(inter.geoms) if geom.is_valid] + shapes.sort(key=lambda x: x[1], reverse=True) + if shapes: + inter = shapes[0][0] + if not isinstance(inter, sg.Polygon) and not inter.is_valid: + return None + else: + return None + + prop_area = inter.area / polygon.area + + if prop_area < threshold_drop_ann: + return None + + inter = _apply_offset(inter, roi_box) + + return ann.wrap( + points=[p for xy in inter.exterior.coords for p in xy], + attributes=deepcopy(ann.attributes), + ) def _extract_rois(self, image: Image) -> list[BboxIntCoords]: """Extracts Tile ROIs from the given image. diff --git a/tests/unit/core/data/test_tiling.py b/tests/unit/core/data/test_tiling.py index 08be1d7072b..184c3b30de0 100644 --- a/tests/unit/core/data/test_tiling.py +++ b/tests/unit/core/data/test_tiling.py @@ -9,8 +9,10 @@ import numpy as np import pytest +import shapely.geometry as sg import torch from datumaro import Dataset as DmDataset +from datumaro import Polygon from omegaconf import DictConfig, OmegaConf from otx.core.config.data import ( DataModuleConfig, @@ -136,6 +138,22 @@ def test_tile_transform(self): num_tile_cols = (width + w_stride - 1) // w_stride assert len(tiled_dataset) == (num_tile_rows * num_tile_cols * len(dataset)), "Incorrect number of tiles" + def test_tile_polygon_func(self): + points = np.array([(1, 2), (3, 5), (4, 2), (4, 6), (1, 6)]) + polygon = Polygon(points=points.flatten().tolist()) + roi = sg.Polygon([(0, 0), (5, 0), (5, 5), (0, 5)]) + + inter_polygon = OTXTileTransform._tile_polygon(polygon, roi, threshold_drop_ann=0.0) + assert isinstance(inter_polygon, Polygon), "Intersection should be a Polygon" + assert inter_polygon.get_area() > 0, "Intersection area should be greater than 0" + + assert ( + OTXTileTransform._tile_polygon(polygon, roi, threshold_drop_ann=1.0) is None + ), "Intersection should be None" + + invalid_polygon = Polygon(points=[0, 0, 5, 0, 5, 5, 5, 0]) + assert OTXTileTransform._tile_polygon(invalid_polygon, roi) is None, "Invalid polygon should be None" + def test_adaptive_tiling(self, fxt_det_data_config): # Enable tile adapter fxt_det_data_config.tile_config.enable_tiler = True