Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Remove DataModuleConfig for the reduced config structure #3688

Merged
merged 15 commits on
Jul 2, 2024
2 changes: 1 addition & 1 deletion src/otx/algo/utils/xai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def _convert_labels_from_hcls_format(

def set_crop_padded_map_flag(explain_config: ExplainConfig, datamodule: OTXDataModule) -> ExplainConfig:
"""If resize with keep_ratio = True was used, set crop_padded_map flag to True."""
for transform in datamodule.config.test_subset.transforms:
for transform in datamodule.test_subset.transforms:
tranf_name = transform["class_path"].split(".")[-1]
if tranf_name == "Resize" and transform["init_args"].get("keep_ratio", False):
explain_config.crop_padded_map = True
Expand Down
4 changes: 2 additions & 2 deletions src/otx/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ def engine_subcommand_parser(subcommand: str, **kwargs) -> tuple[ArgumentParser,
parser.link_arguments("work_dir", "workspace.work_dir")

parser.link_arguments("data_root", "engine.data_root")
parser.link_arguments("data_root", "data.config.data_root")
parser.link_arguments("engine.device", "data.config.device")
parser.link_arguments("data_root", "data.data_root")
parser.link_arguments("engine.device", "data.device")

added_arguments = parser.add_method_arguments(
Engine,
Expand Down
32 changes: 1 addition & 31 deletions src/otx/core/config/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@

from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Any

from otx.core.types.device import DeviceType
from otx.core.types.image import ImageColorChannel
from otx.core.types.transformer_libs import TransformLibType


Expand Down Expand Up @@ -109,34 +107,6 @@ class UnlabeledDataConfig(SubsetConfig):
to_tv_image: bool = True


@dataclass
class DataModuleConfig:
"""DTO for data module configuration."""

data_format: str
data_root: str

train_subset: SubsetConfig
val_subset: SubsetConfig
test_subset: SubsetConfig
unlabeled_subset: UnlabeledDataConfig = field(default_factory=lambda: UnlabeledDataConfig())

tile_config: TileConfig = field(default_factory=lambda: TileConfig())
vpm_config: VisualPromptingConfig = field(default_factory=lambda: VisualPromptingConfig())

mem_cache_size: str = "1GB"
mem_cache_img_max_size: Optional[tuple[int, int]] = None
image_color_channel: ImageColorChannel = ImageColorChannel.RGB
stack_images: bool = True

include_polygons: bool = False
ignore_index: int = 255
unannotated_items_ratio: float = 0.0

auto_num_workers: bool = False
device: DeviceType = DeviceType.auto


@dataclass
class SamplerConfig:
"""Configuration class for defining the sampler used in the data loading process.
Expand Down
31 changes: 18 additions & 13 deletions src/otx/core/data/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from typing import TYPE_CHECKING

from otx.core.config.data import VisualPromptingConfig
from otx.core.types.image import ImageColorChannel
from otx.core.types.task import OTXTaskType
from otx.core.types.transformer_libs import TransformLibType

Expand All @@ -15,7 +17,7 @@
if TYPE_CHECKING:
from datumaro import Dataset as DmDataset

from otx.core.config.data import DataModuleConfig, SubsetConfig
from otx.core.config.data import SubsetConfig
from otx.core.data.mem_cache import MemCacheHandlerBase


Expand Down Expand Up @@ -69,19 +71,24 @@ def create( # ignore too many return statements
cls: type[OTXDatasetFactory],
task: OTXTaskType,
dm_subset: DmDataset,
mem_cache_handler: MemCacheHandlerBase,
cfg_subset: SubsetConfig,
cfg_data_module: DataModuleConfig,
mem_cache_handler: MemCacheHandlerBase,
mem_cache_img_max_size: tuple[int, int] | None = None,
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
include_polygons: bool = False,
ignore_index: int = 255,
vpm_config: VisualPromptingConfig = VisualPromptingConfig(), # noqa: B008
) -> OTXDataset:
"""Create OTXDataset."""
transforms = TransformLibFactory.generate(cfg_subset)
common_kwargs = {
"dm_subset": dm_subset,
"transforms": transforms,
"mem_cache_handler": mem_cache_handler,
"mem_cache_img_max_size": cfg_data_module.mem_cache_img_max_size,
"image_color_channel": cfg_data_module.image_color_channel,
"stack_images": cfg_data_module.stack_images,
"mem_cache_img_max_size": mem_cache_img_max_size,
"image_color_channel": image_color_channel,
"stack_images": stack_images,
"to_tv_image": cfg_subset.to_tv_image,
}

Expand Down Expand Up @@ -117,14 +124,12 @@ def create( # ignore too many return statements
if task in [OTXTaskType.ROTATED_DETECTION, OTXTaskType.INSTANCE_SEGMENTATION]:
from .dataset.instance_segmentation import OTXInstanceSegDataset

# NOTE: DataModuleConfig does not have include_polygons attribute
include_polygons = getattr(cfg_data_module, "include_polygons", False)
return OTXInstanceSegDataset(include_polygons=include_polygons, **common_kwargs)

if task == OTXTaskType.SEMANTIC_SEGMENTATION:
from .dataset.segmentation import OTXSegmentationDataset

return OTXSegmentationDataset(**common_kwargs, ignore_index=cfg_data_module.ignore_index)
return OTXSegmentationDataset(ignore_index=ignore_index, **common_kwargs)

if task == OTXTaskType.ACTION_CLASSIFICATION:
from .dataset.action_classification import OTXActionClsDataset
Expand All @@ -134,15 +139,15 @@ def create( # ignore too many return statements
if task == OTXTaskType.VISUAL_PROMPTING:
from .dataset.visual_prompting import OTXVisualPromptingDataset

use_bbox = getattr(cfg_data_module.vpm_config, "use_bbox", False)
use_point = getattr(cfg_data_module.vpm_config, "use_point", False)
use_bbox = getattr(vpm_config, "use_bbox", False)
use_point = getattr(vpm_config, "use_point", False)
return OTXVisualPromptingDataset(use_bbox=use_bbox, use_point=use_point, **common_kwargs)

if task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
from .dataset.visual_prompting import OTXZeroShotVisualPromptingDataset

use_bbox = getattr(cfg_data_module.vpm_config, "use_bbox", False)
use_point = getattr(cfg_data_module.vpm_config, "use_point", False)
use_bbox = getattr(vpm_config, "use_bbox", False)
use_point = getattr(vpm_config, "use_point", False)
return OTXZeroShotVisualPromptingDataset(use_bbox=use_bbox, use_point=use_point, **common_kwargs)

raise NotImplementedError(task)
8 changes: 7 additions & 1 deletion src/otx/core/data/mem_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ class MemCacheHandlerBase:
It will be combined with LoadImageFromOTXDataset to store/retrieve the samples in memory.
"""

def __init__(self, mem_size: int):
def __init__(self, mem_size: int, mem_cache_img_max_size: tuple[int, int] | None = None):
self._mem_size = mem_size
self._mem_cache_img_max_size = mem_cache_img_max_size
self._init_data_structs(mem_size)

def _init_data_structs(self, mem_size: int) -> None:
Expand All @@ -108,6 +109,11 @@ def mem_size(self) -> int:
"""Get the reserved memory pool size (bytes)."""
return len(self._arr)

@property
def mem_cache_img_max_size(self) -> tuple[int, int] | None:
"""Get the image max size in mem cache."""
return self._mem_cache_img_max_size

def get(self, key: Any) -> tuple[np.ndarray | None, dict | None]: # noqa: ANN401
"""Try to look up the cached item with the given key.

Expand Down
Loading
Loading