Skip to content

Commit

Permalink
Turn on/off classification augmentations (#4039)
Browse files Browse the repository at this point in the history
* Add 'enable' flag for transforms for on/off (default=True)

* Add common augmentations w/ disabled by default

* Add aug combination intg tests for cls

* Update change log

* Update test

* Fix pre-commit
  • Loading branch information
goodsong81 authored Oct 18, 2024
1 parent 78b560d commit b67715f
Show file tree
Hide file tree
Showing 11 changed files with 227 additions and 2 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.

## \[Unreleased\]

### New features

- Turn on/off classification augmentations
(https://github.com/openvinotoolkit/training_extensions/pull/4039)

### Enhancements

- Update visual prompting pipeline for multi-label zero-shot learning support
Expand Down
4 changes: 3 additions & 1 deletion src/otx/core/data/transform_libs/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -1377,7 +1377,7 @@ def forward(self, *_inputs: T_OTXDataEntity) -> T_OTXDataEntity | None:
inputs.image = img
inputs.img_info = _resize_image_info(inputs.img_info, img.shape[:2])

bboxes = inputs.bboxes
bboxes = getattr(inputs, "bboxes", [])
num_bboxes = len(bboxes)
if num_bboxes:
bboxes = project_bboxes(bboxes, warp_matrix)
Expand Down Expand Up @@ -3471,6 +3471,8 @@ def generate(cls, config: SubsetConfig) -> Compose:
transforms = []
for cfg_transform in config.transforms:
if isinstance(cfg_transform, (dict, DictConfig)):
if not cfg_transform.get("enable", True): # Optional "enable: false" flag would remove the transform
continue
cls._configure_input_size(cfg_transform, input_size)
transform = cls._dispatch_transform(cfg_transform)
transforms.append(transform)
Expand Down
12 changes: 12 additions & 0 deletions src/otx/recipe/_base_/data/classification.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,20 @@ train_subset:
- class_path: otx.core.data.transform_libs.torchvision.RandomResizedCrop
init_args:
scale: $(input_size)
- class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
enable: false
- class_path: otx.core.data.transform_libs.torchvision.RandomAffine
enable: false
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.RandomVerticalFlip
enable: false
- class_path: torchvision.transforms.v2.GaussianBlur
enable: false
init_args:
kernel_size: 5
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
Expand All @@ -30,6 +40,8 @@ train_subset:
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
- class_path: torchvision.transforms.v2.GaussianNoise
enable: false
sampler:
class_path: otx.algo.samplers.balanced_sampler.BalancedSampler

Expand Down
12 changes: 12 additions & 0 deletions src/otx/recipe/classification/h_label_cls/efficientnet_b0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,20 @@ overrides:
- class_path: otx.core.data.transform_libs.torchvision.EfficientNetRandomCrop
init_args:
scale: $(input_size)
- class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
enable: false
- class_path: otx.core.data.transform_libs.torchvision.RandomAffine
enable: false
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.RandomVerticalFlip
enable: false
- class_path: torchvision.transforms.v2.GaussianBlur
enable: false
init_args:
kernel_size: 5
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
Expand All @@ -55,3 +65,5 @@ overrides:
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
- class_path: torchvision.transforms.v2.GaussianNoise
enable: false
12 changes: 12 additions & 0 deletions src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,20 @@ overrides:
- class_path: otx.core.data.transform_libs.torchvision.EfficientNetRandomCrop
init_args:
scale: $(input_size)
- class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
enable: false
- class_path: otx.core.data.transform_libs.torchvision.RandomAffine
enable: false
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.RandomVerticalFlip
enable: false
- class_path: torchvision.transforms.v2.GaussianBlur
enable: false
init_args:
kernel_size: 5
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
Expand All @@ -47,3 +57,5 @@ overrides:
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
- class_path: torchvision.transforms.v2.GaussianNoise
enable: false
12 changes: 12 additions & 0 deletions src/otx/recipe/classification/multi_class_cls/efficientnet_b0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,20 @@ overrides:
- class_path: otx.core.data.transform_libs.torchvision.EfficientNetRandomCrop
init_args:
scale: $(input_size)
- class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
enable: false
- class_path: otx.core.data.transform_libs.torchvision.RandomAffine
enable: false
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.RandomVerticalFlip
enable: false
- class_path: torchvision.transforms.v2.GaussianBlur
enable: false
init_args:
kernel_size: 5
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
Expand All @@ -54,3 +64,5 @@ overrides:
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
- class_path: torchvision.transforms.v2.GaussianNoise
enable: false
12 changes: 12 additions & 0 deletions src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,20 @@ overrides:
- class_path: otx.core.data.transform_libs.torchvision.EfficientNetRandomCrop
init_args:
scale: $(input_size)
- class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
enable: false
- class_path: otx.core.data.transform_libs.torchvision.RandomAffine
enable: false
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.RandomVerticalFlip
enable: false
- class_path: torchvision.transforms.v2.GaussianBlur
enable: false
init_args:
kernel_size: 5
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
Expand All @@ -54,3 +64,5 @@ overrides:
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
- class_path: torchvision.transforms.v2.GaussianNoise
enable: false
12 changes: 12 additions & 0 deletions src/otx/recipe/classification/multi_label_cls/efficientnet_b0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,20 @@ overrides:
- class_path: otx.core.data.transform_libs.torchvision.EfficientNetRandomCrop
init_args:
scale: $(input_size)
- class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
enable: false
- class_path: otx.core.data.transform_libs.torchvision.RandomAffine
enable: false
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.RandomVerticalFlip
enable: false
- class_path: torchvision.transforms.v2.GaussianBlur
enable: false
init_args:
kernel_size: 5
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
Expand All @@ -56,3 +66,5 @@ overrides:
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
- class_path: torchvision.transforms.v2.GaussianNoise
enable: false
12 changes: 12 additions & 0 deletions src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,20 @@ overrides:
- class_path: otx.core.data.transform_libs.torchvision.EfficientNetRandomCrop
init_args:
scale: $(input_size)
- class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
enable: false
- class_path: otx.core.data.transform_libs.torchvision.RandomAffine
enable: false
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.RandomVerticalFlip
enable: false
- class_path: torchvision.transforms.v2.GaussianBlur
enable: false
init_args:
kernel_size: 5
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
Expand All @@ -60,3 +70,5 @@ overrides:
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
- class_path: torchvision.transforms.v2.GaussianNoise
enable: false
93 changes: 93 additions & 0 deletions tests/integration/api/test_augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import itertools

import pytest
from datumaro import Dataset as DmDataset
from otx.core.config.data import SamplerConfig, SubsetConfig
from otx.core.data.factory import OTXDatasetFactory
from otx.core.data.mem_cache import MemCacheHandlerSingleton
from otx.core.types.task import OTXTaskType
from otx.engine.utils.auto_configurator import AutoConfigurator


def _test_augmentation(
    recipe: str,
    target_dataset_per_task: dict,
    configurable_augs: list[str],
) -> None:
    """Exercise every on/off combination of the given augmentations for one recipe.

    For each subset of ``configurable_augs`` toggled via the ``enable`` flag, the
    train dataset is rebuilt and the first sample is fetched; the produced image
    shape must be identical across all combinations (toggling augmentations must
    never change the output size).

    Args:
        recipe: Path to a recipe YAML, ending in ``.../<task_name>/<model_name>.yaml``.
        target_dataset_per_task: Maps task name to its dataset root directory.
        configurable_augs: Transform class-name fragments that must carry an
            ``enable`` flag in the recipe's train transforms.
    """
    # Derive task/model identifiers from the recipe path: .../<task>/<model>.yaml
    recipe_tokens = recipe.split("/")
    model_name = recipe_tokens[-1].split(".")[0]
    task_name = recipe_tokens[-2]
    task = OTXTaskType(task_name.upper())
    config = AutoConfigurator(
        data_root=target_dataset_per_task[task_name],
        task=task,
        model_name=model_name,
    ).config
    train_config = config["data"]["train_subset"]
    train_config["input_size"] = (32, 32)  # tiny input keeps the test fast

    # Load dataset once; it is re-wrapped per augmentation combination below.
    dm_dataset = DmDataset.import_from(
        target_dataset_per_task[task_name],
        format=config["data"]["data_format"],
    )
    mem_cache_handler = MemCacheHandlerSingleton.create(
        mode="singleprocessing",  # was "sinlgeprocessing" (typo) — verify create() accepted it only via fallback
        mem_size=0,
    )

    # Pop the sampler config ONCE, outside the loop. Popping inside the loop
    # (as before) made only the first combination use the configured sampler;
    # every later iteration silently fell back to the {} default.
    sampler_args = train_config.pop("sampler", {})

    # Evaluate all on/off aug combinations
    img_shape = None
    for switches in itertools.product([True, False], repeat=len(configurable_augs)):
        # Flip each augmentation's "enable" flag according to this combination.
        for aug_name, switch in zip(configurable_augs, switches):
            aug_found = False
            for aug_config in train_config["transforms"]:
                if aug_name in aug_config["class_path"]:
                    aug_config["enable"] = switch
                    aug_found = True
                    break
            assert aug_found, f"{aug_name} not found in {recipe}"
        # Create dataset with the current combination applied.
        dataset = OTXDatasetFactory.create(
            task=task,
            dm_subset=dm_dataset,
            cfg_subset=SubsetConfig(sampler=SamplerConfig(**sampler_args), **train_config),
            mem_cache_handler=mem_cache_handler,
        )

        # All combinations must yield size-compatible samples.
        data = dataset[0]
        if not img_shape:
            img_shape = data.img_info.img_shape
        else:
            assert img_shape == data.img_info.img_shape


def _is_plain_cls_recipe(recipe: str) -> bool:
    """Keep classification recipes, excluding semi-supervised and torchvision-model ones."""
    if "_cls" not in recipe:
        return False
    return "semi" not in recipe and "tv_" not in recipe


CLS_RECIPES = [recipe for recipe in pytest.RECIPE_LIST if _is_plain_cls_recipe(recipe)]


@pytest.mark.parametrize("recipe", CLS_RECIPES)
def test_augmentation_cls(
    recipe: str,
    fxt_target_dataset_per_task: dict,
):
    """Check every enable/disable combination of the optional classification augs."""
    optional_augs = [
        "PhotoMetricDistortion",
        "RandomAffine",
        "RandomVerticalFlip",
        "GaussianBlur",
        "GaussianNoise",
    ]
    _test_augmentation(recipe, fxt_target_dataset_per_task, optional_augs)
43 changes: 42 additions & 1 deletion tests/unit/core/data/test_transform_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
from __future__ import annotations

from copy import deepcopy
from typing import Any

import pytest
Expand Down Expand Up @@ -119,7 +120,7 @@ def test_transform(self) -> None:


class TestTorchVisionTransformLib:
@pytest.fixture(params=["from_dict", "from_list", "from_compose"])
@pytest.fixture(params=["from_dict", "from_obj", "from_compose"])
def fxt_config(self, request) -> list[dict[str, Any]]:
if request.param == "from_compose":
return v2.Compose(
Expand Down Expand Up @@ -186,6 +187,46 @@ def test_transform(
item = dataset[0]
assert isinstance(item, data_entity_cls)

def test_transform_enable_flag(self) -> None:
    """The optional ``enable: false`` flag removes a transform from the generated pipeline.

    An explicit ``enable: true`` (like an absent flag) must be a no-op:
    the transform stays in the composed pipeline.
    """
    prefix = "torchvision.transforms.v2"
    # 4-transform baseline pipeline; indices below refer to this order.
    cfg_str = f"""
    transforms:
      - class_path: {prefix}.RandomResizedCrop
        init_args:
          size: [224, 224]
          antialias: True
      - class_path: {prefix}.RandomHorizontalFlip
        init_args:
          p: 0.5
      - class_path: {prefix}.ToDtype
        init_args:
          dtype: ${{as_torch_dtype:torch.float32}}
          scale: True
      - class_path: {prefix}.Normalize
        init_args:
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
    """
    cfg_org = OmegaConf.create(cfg_str)

    # Disabling the 1st transform drops it: 4 -> 3 entries.
    cfg = deepcopy(cfg_org)
    cfg.transforms[0].enable = False  # Remove 1st
    transform = TorchVisionTransformLib.generate(cfg)
    assert len(transform.transforms) == 3
    assert "RandomResizedCrop" not in repr(transform)

    # Disabling a middle transform works the same way.
    cfg = deepcopy(cfg_org)
    cfg.transforms[1].enable = False  # Remove 2nd
    transform = TorchVisionTransformLib.generate(cfg)
    assert len(transform.transforms) == 3
    assert "RandomHorizontalFlip" not in repr(transform)

    # Explicit enable=True leaves the pipeline untouched.
    cfg = deepcopy(cfg_org)
    cfg.transforms[2].enable = True  # No effect
    transform = TorchVisionTransformLib.generate(cfg)
    assert len(transform.transforms) == 4
    assert "ToDtype" in repr(transform)

@pytest.fixture()
def fxt_config_w_input_size(self) -> list[dict[str, Any]]:
cfg = """
Expand Down

0 comments on commit b67715f

Please sign in to comment.