Skip to content

Commit

Permalink
Fix tiling test (#3637)
Browse files Browse the repository at this point in the history
* fix tile test

* add description to tile transform
  • Loading branch information
eugene123tw authored Jun 19, 2024
1 parent 340ea62 commit 4396b5f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 22 deletions.
3 changes: 3 additions & 0 deletions src/otx/core/data/dataset/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class OTXTileTransform(Tile):
extractor (DmDataset): Dataset subset to extract tiles from.
tile_size (tuple[int, int]): Tile size.
overlap (tuple[float, float]): Overlap ratio.
Overlap values are clipped between 0 and 0.9 to ensure the stride is not too small.
threshold_drop_ann (float): Threshold to drop annotations.
with_full_img (bool): Include full image in the tiles.
"""
Expand All @@ -76,6 +77,8 @@ def __init__(
threshold_drop_ann: float,
with_full_img: bool,
) -> None:
# NOTE: clip overlap to [0, 0.9]
overlap = max(0, min(overlap[0], 0.9)), max(0, min(overlap[1], 0.9))
super().__init__(
extractor,
(0, 0),
Expand Down
4 changes: 2 additions & 2 deletions src/otx/core/data/tile_adaptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ def adapt_tile_config(tile_config: TileConfig, dataset: Dataset) -> None:

if tile_overlap >= 0.9:
# Use the average object area if the tile overlap is too large to prevent 0 stride.
tile_overlap = avg_size / tile_size
log.info(f"----> (too big) tile_overlap: {avg_size} / {tile_size} = {tile_overlap}")
tile_overlap = min(avg_size / tile_size, 0.9)
log.info(f"----> (too big) tile_overlap: {avg_size} / {tile_size} = min[{tile_overlap}, 0.9]")

# TODO(Eugene): how to validate lower/upper_bound? dataclass? pydantic?
# https://github.com/openvinotoolkit/training_extensions/pull/2903
Expand Down
38 changes: 18 additions & 20 deletions tests/unit/core/data/test_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from otx.core.model.detection import OTXDetectionModel
from otx.core.model.instance_segmentation import OTXInstanceSegModel
from otx.core.types.task import OTXTaskType
from otx.core.types.transformer_libs import TransformLibType
from torchvision import tv_tensors

from tests.test_helpers import generate_random_bboxes
Expand All @@ -39,12 +40,12 @@ def mock_otx_det_model(self) -> OTXDetectionModel:
return create_autospec(OTXDetectionModel)

@pytest.fixture()
def fxt_mmcv_det_transform_config(self) -> list[DictConfig]:
mmdet_base = OmegaConf.load("src/otx/recipe/_base_/data/mmdet_base.yaml")
def fxt_tv_det_transform_config(self) -> list[DictConfig]:
mmdet_base = OmegaConf.load("src/otx/recipe/_base_/data/torchvision_base.yaml")
return mmdet_base.config.train_subset.transforms

@pytest.fixture()
def fxt_det_data_config(self, fxt_mmcv_det_transform_config) -> OTXDataModule:
def fxt_det_data_config(self, fxt_tv_det_transform_config) -> OTXDataModule:
data_root = Path(__file__).parent.parent.parent.parent / "assets" / "car_tree_bug"

batch_size = 8
Expand All @@ -56,35 +57,31 @@ def fxt_det_data_config(self, fxt_mmcv_det_transform_config) -> OTXDataModule:
subset_name="train",
batch_size=batch_size,
num_workers=num_workers,
transform_lib_type="MMDET",
transforms=fxt_mmcv_det_transform_config,
transform_lib_type=TransformLibType.TORCHVISION,
transforms=fxt_tv_det_transform_config,
),
val_subset=SubsetConfig(
subset_name="val",
batch_size=batch_size,
num_workers=num_workers,
transform_lib_type="MMDET",
transforms=fxt_mmcv_det_transform_config,
transform_lib_type=TransformLibType.TORCHVISION,
transforms=fxt_tv_det_transform_config,
),
test_subset=SubsetConfig(
subset_name="test",
batch_size=batch_size,
num_workers=num_workers,
transform_lib_type="MMDET",
transforms=fxt_mmcv_det_transform_config,
transform_lib_type=TransformLibType.TORCHVISION,
transforms=fxt_tv_det_transform_config,
),
tile_config=TileConfig(),
vpm_config=VisualPromptingConfig(),
)

@pytest.fixture()
def fxt_instseg_data_config(self, fxt_mmcv_det_transform_config) -> OTXDataModule:
def fxt_instseg_data_config(self, fxt_tv_det_transform_config) -> OTXDataModule:
data_root = Path(__file__).parent.parent.parent.parent / "assets" / "car_tree_bug"

for transform in fxt_mmcv_det_transform_config:
if transform.type == "LoadAnnotations":
transform.with_mask = True

batch_size = 8
num_workers = 0
return DataModuleConfig(
Expand All @@ -94,22 +91,22 @@ def fxt_instseg_data_config(self, fxt_mmcv_det_transform_config) -> OTXDataModul
subset_name="train",
batch_size=batch_size,
num_workers=num_workers,
transform_lib_type="MMDET",
transforms=fxt_mmcv_det_transform_config,
transform_lib_type=TransformLibType.TORCHVISION,
transforms=fxt_tv_det_transform_config,
),
val_subset=SubsetConfig(
subset_name="val",
batch_size=batch_size,
num_workers=num_workers,
transform_lib_type="MMDET",
transforms=fxt_mmcv_det_transform_config,
transform_lib_type=TransformLibType.TORCHVISION,
transforms=fxt_tv_det_transform_config,
),
test_subset=SubsetConfig(
subset_name="test",
batch_size=batch_size,
num_workers=num_workers,
transform_lib_type="MMDET",
transforms=fxt_mmcv_det_transform_config,
transform_lib_type=TransformLibType.TORCHVISION,
transforms=fxt_tv_det_transform_config,
),
tile_config=TileConfig(),
vpm_config=VisualPromptingConfig(),
Expand Down Expand Up @@ -239,6 +236,7 @@ def test_tile_transform(self):
rng = np.random.default_rng()
tile_size = rng.integers(low=100, high=500, size=(2,))
overlap = rng.random(2)
overlap = overlap.clip(0, 0.9)
threshold_drop_ann = rng.random()
tiled_dataset = DmDataset.import_from("tests/assets/car_tree_bug", format="coco_instances")
tiled_dataset.transform(
Expand Down

0 comments on commit 4396b5f

Please sign in to comment.