diff --git a/CHANGELOG.md b/CHANGELOG.md
index cdf5c5840d5..5f35d17ce7e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
 
 ## \[Unreleased\]
 
+### New features
+
+- Turn classification augmentations on/off
+  (https://github.com/openvinotoolkit/training_extensions/pull/4039)
+
 ### Enhancements
 
 - Update visual prompting pipeline for multi-label zero-shot learning support
diff --git a/src/otx/core/data/transform_libs/torchvision.py b/src/otx/core/data/transform_libs/torchvision.py
index 8cfe3ea1636..1c77ca2eb8e 100644
--- a/src/otx/core/data/transform_libs/torchvision.py
+++ b/src/otx/core/data/transform_libs/torchvision.py
@@ -1377,7 +1377,7 @@ def forward(self, *_inputs: T_OTXDataEntity) -> T_OTXDataEntity | None:
         inputs.image = img
         inputs.img_info = _resize_image_info(inputs.img_info, img.shape[:2])
 
-        bboxes = inputs.bboxes
+        bboxes = getattr(inputs, "bboxes", [])
         num_bboxes = len(bboxes)
         if num_bboxes:
             bboxes = project_bboxes(bboxes, warp_matrix)
@@ -3471,6 +3471,8 @@ def generate(cls, config: SubsetConfig) -> Compose:
         transforms = []
         for cfg_transform in config.transforms:
             if isinstance(cfg_transform, (dict, DictConfig)):
+                if not cfg_transform.get("enable", True):  # An optional "enable: false" entry removes the transform from the pipeline
+                    continue
                 cls._configure_input_size(cfg_transform, input_size)
                 transform = cls._dispatch_transform(cfg_transform)
                 transforms.append(transform)
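
Note: the `generate()` change above is what drives the recipe entries below. Any dict-style transform can now be switched off without deleting its entry. A minimal sketch of the intended recipe usage (illustrative fragment based on the transforms in this patch, not copied verbatim from any one recipe file):

    train_subset:
      transforms:
        - class_path: torchvision.transforms.v2.GaussianBlur
          enable: true    # opt in; "enable" defaults to true when omitted
          init_args:
            kernel_size: 5
        - class_path: torchvision.transforms.v2.GaussianNoise
          enable: false   # skipped by TorchVisionTransformLib.generate()
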
diff --git a/src/otx/recipe/_base_/data/classification.yaml b/src/otx/recipe/_base_/data/classification.yaml
index e8ee41bf15e..8477f47c3e6 100644
--- a/src/otx/recipe/_base_/data/classification.yaml
+++ b/src/otx/recipe/_base_/data/classification.yaml
@@ -18,10 +18,20 @@ train_subset:
     - class_path: otx.core.data.transform_libs.torchvision.RandomResizedCrop
       init_args:
         scale: $(input_size)
+    - class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
+      enable: false
+    - class_path: otx.core.data.transform_libs.torchvision.RandomAffine
+      enable: false
     - class_path: otx.core.data.transform_libs.torchvision.RandomFlip
       init_args:
         prob: 0.5
         is_numpy_to_tvtensor: true
+    - class_path: torchvision.transforms.v2.RandomVerticalFlip
+      enable: false
+    - class_path: torchvision.transforms.v2.GaussianBlur
+      enable: false
+      init_args:
+        kernel_size: 5
     - class_path: torchvision.transforms.v2.ToDtype
       init_args:
         dtype: ${as_torch_dtype:torch.float32}
@@ -30,6 +40,8 @@ train_subset:
       init_args:
         mean: [123.675, 116.28, 103.53]
         std: [58.395, 57.12, 57.375]
+    - class_path: torchvision.transforms.v2.GaussianNoise
+      enable: false
 
   sampler:
     class_path: otx.algo.samplers.balanced_sampler.BalancedSampler
diff --git a/src/otx/recipe/classification/h_label_cls/efficientnet_b0.yaml b/src/otx/recipe/classification/h_label_cls/efficientnet_b0.yaml
index d0ea7daec7b..4bfbe3fc121 100644
--- a/src/otx/recipe/classification/h_label_cls/efficientnet_b0.yaml
+++ b/src/otx/recipe/classification/h_label_cls/efficientnet_b0.yaml
@@ -43,10 +43,20 @@ overrides:
         - class_path: otx.core.data.transform_libs.torchvision.EfficientNetRandomCrop
           init_args:
             scale: $(input_size)
+        - class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
+          enable: false
+        - class_path: otx.core.data.transform_libs.torchvision.RandomAffine
+          enable: false
         - class_path: otx.core.data.transform_libs.torchvision.RandomFlip
           init_args:
             prob: 0.5
             is_numpy_to_tvtensor: true
+        - class_path: torchvision.transforms.v2.RandomVerticalFlip
+          enable: false
+        - class_path: torchvision.transforms.v2.GaussianBlur
+          enable: false
+          init_args:
+            kernel_size: 5
         - class_path: torchvision.transforms.v2.ToDtype
           init_args:
             dtype: ${as_torch_dtype:torch.float32}
@@ -55,3 +65,5 @@ overrides:
           init_args:
             mean: [123.675, 116.28, 103.53]
             std: [58.395, 57.12, 57.375]
+        - class_path: torchvision.transforms.v2.GaussianNoise
+          enable: false
diff --git a/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml b/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml
index 1dc6f209979..500cc168baa 100644
--- a/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml
+++ b/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml
@@ -35,10 +35,20 @@ overrides:
         - class_path: otx.core.data.transform_libs.torchvision.EfficientNetRandomCrop
           init_args:
             scale: $(input_size)
+        - class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
+          enable: false
+        - class_path: otx.core.data.transform_libs.torchvision.RandomAffine
+          enable: false
         - class_path: otx.core.data.transform_libs.torchvision.RandomFlip
           init_args:
             prob: 0.5
             is_numpy_to_tvtensor: true
+        - class_path: torchvision.transforms.v2.RandomVerticalFlip
+          enable: false
+        - class_path: torchvision.transforms.v2.GaussianBlur
+          enable: false
+          init_args:
+            kernel_size: 5
         - class_path: torchvision.transforms.v2.ToDtype
           init_args:
             dtype: ${as_torch_dtype:torch.float32}
@@ -47,3 +57,5 @@ overrides:
           init_args:
             mean: [123.675, 116.28, 103.53]
             std: [58.395, 57.12, 57.375]
+        - class_path: torchvision.transforms.v2.GaussianNoise
+          enable: false
diff --git a/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml b/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml
index 872d28789ef..428fb89055b 100644
--- a/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml
+++ b/src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml
@@ -42,10 +42,20 @@ overrides:
         - class_path: otx.core.data.transform_libs.torchvision.EfficientNetRandomCrop
           init_args:
             scale: $(input_size)
+        - class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
+          enable: false
+        - class_path: otx.core.data.transform_libs.torchvision.RandomAffine
+          enable: false
         - class_path: otx.core.data.transform_libs.torchvision.RandomFlip
           init_args:
             prob: 0.5
             is_numpy_to_tvtensor: true
+        - class_path: torchvision.transforms.v2.RandomVerticalFlip
+          enable: false
+        - class_path: torchvision.transforms.v2.GaussianBlur
+          enable: false
+          init_args:
+            kernel_size: 5
         - class_path: torchvision.transforms.v2.ToDtype
           init_args:
             dtype: ${as_torch_dtype:torch.float32}
@@ -54,3 +64,5 @@ overrides:
           init_args:
             mean: [123.675, 116.28, 103.53]
             std: [58.395, 57.12, 57.375]
+        - class_path: torchvision.transforms.v2.GaussianNoise
+          enable: false
diff --git a/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml b/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml
index 0dd6daf26f7..2454c0e7094 100644
--- a/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml
+++ b/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml
@@ -42,10 +42,20 @@ overrides:
         - class_path: otx.core.data.transform_libs.torchvision.EfficientNetRandomCrop
           init_args:
             scale: $(input_size)
+        - class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
+          enable: false
+        - class_path: otx.core.data.transform_libs.torchvision.RandomAffine
+          enable: false
         - class_path: otx.core.data.transform_libs.torchvision.RandomFlip
           init_args:
             prob: 0.5
             is_numpy_to_tvtensor: true
+        - class_path: torchvision.transforms.v2.RandomVerticalFlip
+          enable: false
+        - class_path: torchvision.transforms.v2.GaussianBlur
+          enable: false
+          init_args:
+            kernel_size: 5
         - class_path: torchvision.transforms.v2.ToDtype
           init_args:
             dtype: ${as_torch_dtype:torch.float32}
@@ -54,3 +64,5 @@ overrides:
           init_args:
             mean: [123.675, 116.28, 103.53]
             std: [58.395, 57.12, 57.375]
+        - class_path: torchvision.transforms.v2.GaussianNoise
+          enable: false
diff --git a/src/otx/recipe/classification/multi_label_cls/efficientnet_b0.yaml b/src/otx/recipe/classification/multi_label_cls/efficientnet_b0.yaml
index d2b11411a51..f3625158439 100644
--- a/src/otx/recipe/classification/multi_label_cls/efficientnet_b0.yaml
+++ b/src/otx/recipe/classification/multi_label_cls/efficientnet_b0.yaml
@@ -44,10 +44,20 @@ overrides:
         - class_path: otx.core.data.transform_libs.torchvision.EfficientNetRandomCrop
           init_args:
             scale: $(input_size)
+        - class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
+          enable: false
+        - class_path: otx.core.data.transform_libs.torchvision.RandomAffine
+          enable: false
         - class_path: otx.core.data.transform_libs.torchvision.RandomFlip
           init_args:
             prob: 0.5
             is_numpy_to_tvtensor: true
+        - class_path: torchvision.transforms.v2.RandomVerticalFlip
+          enable: false
+        - class_path: torchvision.transforms.v2.GaussianBlur
+          enable: false
+          init_args:
+            kernel_size: 5
         - class_path: torchvision.transforms.v2.ToDtype
           init_args:
             dtype: ${as_torch_dtype:torch.float32}
@@ -56,3 +66,5 @@ overrides:
           init_args:
             mean: [123.675, 116.28, 103.53]
             std: [58.395, 57.12, 57.375]
+        - class_path: torchvision.transforms.v2.GaussianNoise
+          enable: false
diff --git a/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml b/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml
index 14f3b605f12..a304d76542b 100644
--- a/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml
+++ b/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml
@@ -48,10 +48,20 @@ overrides:
        - class_path: otx.core.data.transform_libs.torchvision.EfficientNetRandomCrop
           init_args:
             scale: $(input_size)
+        - class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
+          enable: false
+        - class_path: otx.core.data.transform_libs.torchvision.RandomAffine
+          enable: false
         - class_path: otx.core.data.transform_libs.torchvision.RandomFlip
           init_args:
             prob: 0.5
             is_numpy_to_tvtensor: true
+        - class_path: torchvision.transforms.v2.RandomVerticalFlip
+          enable: false
+        - class_path: torchvision.transforms.v2.GaussianBlur
+          enable: false
+          init_args:
+            kernel_size: 5
         - class_path: torchvision.transforms.v2.ToDtype
           init_args:
             dtype: ${as_torch_dtype:torch.float32}
@@ -60,3 +70,5 @@ overrides:
           init_args:
             mean: [123.675, 116.28, 103.53]
             std: [58.395, 57.12, 57.375]
+        - class_path: torchvision.transforms.v2.GaussianNoise
+          enable: false
diff --git a/tests/integration/api/test_augmentation.py b/tests/integration/api/test_augmentation.py
new file mode 100644
index 00000000000..d55f5005cb9
--- /dev/null
+++ b/tests/integration/api/test_augmentation.py
@@ -0,0 +1,93 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import itertools
+
+import pytest
+from datumaro import Dataset as DmDataset
+from otx.core.config.data import SamplerConfig, SubsetConfig
+from otx.core.data.factory import OTXDatasetFactory
+from otx.core.data.mem_cache import MemCacheHandlerSingleton
+from otx.core.types.task import OTXTaskType
+from otx.engine.utils.auto_configurator import AutoConfigurator
+
+
+def _test_augmentation(
+    recipe: str,
+    target_dataset_per_task: dict,
+    configurable_augs: list[str],
+) -> None:
+    # Load recipe
+    recipe_tokens = recipe.split("/")
+    model_name = recipe_tokens[-1].split(".")[0]
+    task_name = recipe_tokens[-2]
+    task = OTXTaskType(task_name.upper())
+    config = AutoConfigurator(
+        data_root=target_dataset_per_task[task_name],
+        task=task,
+        model_name=model_name,
+    ).config
+    train_config = config["data"]["train_subset"]
+    train_config["input_size"] = (32, 32)
+
+    # Load dataset
+    dm_dataset = DmDataset.import_from(
+        target_dataset_per_task[task_name],
+        format=config["data"]["data_format"],
+    )
+    mem_cache_handler = MemCacheHandlerSingleton.create(
+        mode="singleprocessing",
+        mem_size=0,
+    )
+
+    # Evaluate all on/off augmentation combinations
+    img_shape = None
+    for switches in itertools.product([True, False], repeat=len(configurable_augs)):
+        # Turn each configurable augmentation on or off
+        for aug_name, switch in zip(configurable_augs, switches):
+            aug_found = False
+            for aug_config in train_config["transforms"]:
+                if aug_name in aug_config["class_path"]:
+                    aug_config["enable"] = switch
+                    aug_found = True
+                    break
+            assert aug_found, f"{aug_name} not found in {recipe}"
+        # Create dataset
+        dataset = OTXDatasetFactory.create(
+            task=task,
+            dm_subset=dm_dataset,
+            cfg_subset=SubsetConfig(sampler=SamplerConfig(**train_config.pop("sampler", {})), **train_config),
+            mem_cache_handler=mem_cache_handler,
+        )
+
+        # Check that every augmentation combination yields the same output image shape
+        data = dataset[0]
+        if not img_shape:
+            img_shape = data.img_info.img_shape
+        else:
+            assert img_shape == data.img_info.img_shape
+
+
+CLS_RECIPES = [
+    recipe for recipe in pytest.RECIPE_LIST if "_cls" in recipe and "semi" not in recipe and "tv_" not in recipe
+]
+
+
+@pytest.mark.parametrize(
+    "recipe",
+    CLS_RECIPES,
+)
+def test_augmentation_cls(
+    recipe: str,
+    fxt_target_dataset_per_task: dict,
+):
+    configurable_augs = [
+        "PhotoMetricDistortion",
+        "RandomAffine",
+        "RandomVerticalFlip",
+        "GaussianBlur",
+        "GaussianNoise",
+    ]
+    _test_augmentation(recipe, fxt_target_dataset_per_task, configurable_augs)
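
For reference, the new integration test can be exercised on its own with pytest (assuming the repository's dataset fixtures are set up locally):

    pytest tests/integration/api/test_augmentation.py -k cls
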
diff --git a/tests/unit/core/data/test_transform_libs.py b/tests/unit/core/data/test_transform_libs.py
index 9af540588d0..53440aa5d4b 100644
--- a/tests/unit/core/data/test_transform_libs.py
+++ b/tests/unit/core/data/test_transform_libs.py
@@ -3,6 +3,7 @@
 #
 
 from __future__ import annotations
+from copy import deepcopy
 from typing import Any
 
 import pytest
@@ -119,7 +120,7 @@ def test_transform(self) -> None:
 
 
 class TestTorchVisionTransformLib:
-    @pytest.fixture(params=["from_dict", "from_list", "from_compose"])
+    @pytest.fixture(params=["from_dict", "from_obj", "from_compose"])
     def fxt_config(self, request) -> list[dict[str, Any]]:
         if request.param == "from_compose":
             return v2.Compose(
@@ -186,6 +187,46 @@ def test_transform(
         item = dataset[0]
         assert isinstance(item, data_entity_cls)
 
+    def test_transform_enable_flag(self) -> None:
+        prefix = "torchvision.transforms.v2"
+        cfg_str = f"""
+        transforms:
+          - class_path: {prefix}.RandomResizedCrop
+            init_args:
+              size: [224, 224]
+              antialias: True
+          - class_path: {prefix}.RandomHorizontalFlip
+            init_args:
+              p: 0.5
+          - class_path: {prefix}.ToDtype
+            init_args:
+              dtype: ${{as_torch_dtype:torch.float32}}
+              scale: True
+          - class_path: {prefix}.Normalize
+            init_args:
+              mean: [0.485, 0.456, 0.406]
+              std: [0.229, 0.224, 0.225]
+        """
+        cfg_org = OmegaConf.create(cfg_str)
+
+        cfg = deepcopy(cfg_org)
+        cfg.transforms[0].enable = False  # Disable the 1st transform
+        transform = TorchVisionTransformLib.generate(cfg)
+        assert len(transform.transforms) == 3
+        assert "RandomResizedCrop" not in repr(transform)
+
+        cfg = deepcopy(cfg_org)
+        cfg.transforms[1].enable = False  # Disable the 2nd transform
+        transform = TorchVisionTransformLib.generate(cfg)
+        assert len(transform.transforms) == 3
+        assert "RandomHorizontalFlip" not in repr(transform)
+
+        cfg = deepcopy(cfg_org)
+        cfg.transforms[2].enable = True  # Explicit enable matches the default; no effect
+        transform = TorchVisionTransformLib.generate(cfg)
+        assert len(transform.transforms) == 4
+        assert "ToDtype" in repr(transform)
+
     @pytest.fixture()
     def fxt_config_w_input_size(self) -> list[dict[str, Any]]:
         cfg = """