From 7a201525c5c36595307fa7862f029019a9cf89cd Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Mon, 13 Feb 2023 15:23:13 +0900 Subject: [PATCH 01/15] Add MemCacheHandler Signed-off-by: Kim, Vinnam --- .../datasets/pipelines/caching/__init__.py | 8 + .../load_image_from_file_with_cache.py | 82 +++++++++ .../pipelines/caching/mem_cache_handler.py | 152 ++++++++++++++++ tests/unit/mpa/test_caching.py | 165 ++++++++++++++++++ 4 files changed, 407 insertions(+) create mode 100644 otx/mpa/modules/datasets/pipelines/caching/__init__.py create mode 100644 otx/mpa/modules/datasets/pipelines/caching/load_image_from_file_with_cache.py create mode 100644 otx/mpa/modules/datasets/pipelines/caching/mem_cache_handler.py create mode 100644 tests/unit/mpa/test_caching.py diff --git a/otx/mpa/modules/datasets/pipelines/caching/__init__.py b/otx/mpa/modules/datasets/pipelines/caching/__init__.py new file mode 100644 index 00000000000..e0978a47373 --- /dev/null +++ b/otx/mpa/modules/datasets/pipelines/caching/__init__.py @@ -0,0 +1,8 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +from .load_image_from_file_with_cache import LoadImageFromFileWithCache +from .mem_cache_handler import MemCacheHandler + +__all__ = ["MemCacheHandler", "LoadImageFromFileWithCache"] diff --git a/otx/mpa/modules/datasets/pipelines/caching/load_image_from_file_with_cache.py b/otx/mpa/modules/datasets/pipelines/caching/load_image_from_file_with_cache.py new file mode 100644 index 00000000000..2d062b7945d --- /dev/null +++ b/otx/mpa/modules/datasets/pipelines/caching/load_image_from_file_with_cache.py @@ -0,0 +1,82 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import os.path as osp + +import mmcv +import numpy as np +from mmdet.datasets.builder import PIPELINES +from mmdet.datasets.pipelines import LoadImageFromFile + +from .mem_cache_handler import MemCacheHandler + + +@PIPELINES.register_module() +class LoadImageFromFileWithCache(LoadImageFromFile): + """Load an image from file. + + Required keys are "img_prefix" and "img_info" (a dict that must contain the + key "filename"). Added or updated keys are "filename", "img", "img_shape", + "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), + "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + color_type (str): The flag argument for :func:`mmcv.imfrombytes`. + Defaults to 'color'. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + """ + + def __init__(self, to_float32=False, color_type="color", channel_order="bgr"): + self.to_float32 = to_float32 + self.color_type = color_type + self.channel_order = channel_order + self.file_client = mmcv.FileClient(backend="disk") + self.mem_cache_handler = MemCacheHandler() + + def __call__(self, results): + """Call functions to load image and get image meta information. + + Args: + results (dict): Result dict from :obj:`mmdet.CustomDataset`. + + Returns: + dict: The dict contains loaded image and meta information. 
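+
+        Example of the minimal input this step expects (the path and file
+        name are illustrative, not taken from the codebase):
+
+            load = LoadImageFromFileWithCache()
+            results = dict(img_prefix="/data/images", img_info=dict(filename="a.png"))
+            results = load(results)
+            assert results["img"].shape == results["img_shape"]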
+ """ + + if results["img_prefix"] is not None: + filename = osp.join(results["img_prefix"], results["img_info"]["filename"]) + else: + filename = results["img_info"]["filename"] + + img = self.mem_cache_handler.get(key=filename) + + if img is None: + img_bytes = self.file_client.get(filename) + img = mmcv.imfrombytes(img_bytes, flag=self.color_type, channel_order=self.channel_order) + self.mem_cache_handler.put(key=filename, data=img) + + if self.to_float32: + img = img.astype(np.float32) + + results["filename"] = filename + results["ori_filename"] = results["img_info"]["filename"] + results["img"] = img + results["img_shape"] = img.shape + results["ori_shape"] = img.shape + results["img_fields"] = ["img"] + return results + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"to_float32={self.to_float32}, " + f"color_type='{self.color_type}', " + f"channel_order='{self.channel_order}', " + f"mem_cache_handler={self.mem_cache_handler})" + ) diff --git a/otx/mpa/modules/datasets/pipelines/caching/mem_cache_handler.py b/otx/mpa/modules/datasets/pipelines/caching/mem_cache_handler.py new file mode 100644 index 00000000000..16cffb6d843 --- /dev/null +++ b/otx/mpa/modules/datasets/pipelines/caching/mem_cache_handler.py @@ -0,0 +1,152 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import ctypes as ct +import multiprocessing as mp +import re +from typing import Optional + +import numpy as np + +from otx.mpa.utils.logger import get_logger + +logger = get_logger() + + +class _DummyLock: + def __enter__(self, *args, **kwargs): + pass + + def __exit__(self, *args, **kwargs): + pass + + +class MemCacheHandlerForSP: + def __init__(self, mem_size: int): + self._init_data_structs(mem_size) + + def _init_data_structs(self, mem_size: int): + self.arr = (ct.c_uint8 * mem_size)() + self.cur_page = ct.c_size_t(0) + self.cache_addr = {} + self.lock = _DummyLock() + + def __len__(self): + return len(self.cache_addr) + + @property + def mem_size(self) -> int: + return len(self.arr) + + def get(self, key: str) -> Optional[np.ndarray]: + if key not in self.cache_addr: + return None + + addr = self.cache_addr[key] + + offset, count, shape, strides = addr + + data = np.frombuffer(self.arr, dtype=np.uint8, count=count, offset=offset) + return np.lib.stride_tricks.as_strided(data, shape, strides) + + def put(self, key: str, data: np.ndarray) -> Optional[int]: + assert data.dtype == np.uint8 + + with self.lock: + new_page = self.cur_page.value + data.size + + if key in self.cache_addr or new_page > self.mem_size: + return None + + offset = ct.byref(self.arr, self.cur_page.value) + ct.memmove(offset, data.ctypes.data, data.size) + + self.cache_addr[key] = ( + self.cur_page.value, + data.size, + data.shape, + data.strides, + ) + self.cur_page.value = new_page + return new_page + + def __repr__(self): + return ( + f"{self.__class__.__name__} " + f"uses {self.cur_page.value} / {self.mem_size} memory pool and " + f"store {len(self)} items." 
+ ) + + +class MemCacheHandlerForMP(MemCacheHandlerForSP): + def __init__(self, mem_size: int): + super().__init__(mem_size) + + def _init_data_structs(self, mem_size: int): + self.arr = mp.Array(ct.c_uint8, mem_size, lock=False) + self.cur_page = mp.Value(ct.c_size_t, 0, lock=False) + + self.manager = mp.Manager() + self.cache_addr = self.manager.dict() + self.lock = mp.Lock() + + def __del__(self): + self.manager.shutdown() + + +class MemCacheHandler(MemCacheHandlerForSP): + instance = Optional[MemCacheHandlerForSP] + + def __init__(self): + pass + + def __new__(cls) -> Optional[MemCacheHandlerForSP]: + if not hasattr(cls, "instance"): + raise RuntimeError(f"Before calling {cls.__name__}(), you should call {cls.__name__}.create() first.") + + return cls.instance + + @classmethod + def create(cls, mode: str, mem_size: str) -> Optional[MemCacheHandlerForSP]: + mem_size = cls._parse_mem_size_str(mem_size) + logger.info(f"Try to create a {mem_size} size memory pool.") + + if mode == "multiprocessing": + cls.instance = MemCacheHandlerForMP(mem_size) + elif mode == "singleprocessing": + cls.instance = MemCacheHandlerForSP(mem_size) + else: + raise ValueError(f"{mode} is unknown mode.") + + return cls.instance + + @staticmethod + def _parse_mem_size_str(mem_size: str) -> int: + assert isinstance(mem_size, str) + + m = re.match(r"^([\d\.]+)\s*([a-zA-Z]{0,3})$", mem_size.strip()) + + if m is None: + raise ValueError(f"Cannot parse {mem_size} string.") + + units = { + "": 1, + "B": 1, + "KB": 2**10, + "MB": 2**20, + "GB": 2**30, + "KIB": 10**3, + "MIB": 10**6, + "GIB": 10**9, + "K": 2**10, + "M": 2**20, + "G": 2**30, + } + + number, unit = int(m.group(1)), m.group(2).upper() + + if unit not in units: + raise ValueError(f"{mem_size} has disallowed unit ({unit}).") + + return number * units[unit] diff --git a/tests/unit/mpa/test_caching.py b/tests/unit/mpa/test_caching.py new file mode 100644 index 00000000000..ab8e202dde9 --- /dev/null +++ b/tests/unit/mpa/test_caching.py @@ -0,0 +1,165 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import os.path as osp +import string +from tempfile import TemporaryDirectory + +import cv2 +import numpy as np +import pytest +from torch.utils.data import DataLoader, Dataset + +from otx.mpa.modules.datasets.pipelines.caching import ( + LoadImageFromFileWithCache, + MemCacheHandler, +) + + +@pytest.fixture +def fxt_data_list(): + np.random.seed(3003) + + num_data = 10 + h = w = key_len = 16 + + data_list = [] + for _ in range(num_data): + data = np.random.randint(0, 256, size=[h, w, 3], dtype=np.uint8) + key = "".join( + [string.ascii_lowercase[i] for i in np.random.randint(0, len(string.ascii_lowercase), size=[key_len])] + ) + data_list += [(key, data)] + + return data_list + + +@pytest.fixture +def fxt_caching_dataset_cls(fxt_data_list): + with TemporaryDirectory() as img_prefix: + for key, data in fxt_data_list: + cv2.imwrite(osp.join(img_prefix, key + ".png"), data) + + class CachingDataset(Dataset): + def __init__(self) -> None: + super().__init__() + self.data_list = fxt_data_list + self.load = LoadImageFromFileWithCache() + self.file_get_count = 0 + + __get = self.load.file_client.get + + def _get(filepath): + self.file_get_count += 1 + return __get(filepath) + + self.load.file_client.get = _get + + def reset_file_count(self): + self.file_get_count = 0 + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, index): + key, _ = self.data_list[index] + results = { + "img_prefix": img_prefix, + 
"img_info": {"filename": key + ".png"}, + } + return self.load(results) + + yield CachingDataset + + +def get_data_list_size(data_list): + size = 0 + for _, data in data_list: + size += data.size + return size + + +class TestMemCacheHandler: + @pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"]) + def test_fully_caching(self, mode, fxt_data_list): + mem_size = str(get_data_list_size(fxt_data_list)) + MemCacheHandler.create(mode, mem_size) + handler = MemCacheHandler() + + for key, data in fxt_data_list: + assert handler.put(key, data) > 0 + + for key, data in fxt_data_list: + get_data = handler.get(key) + + assert np.array_equal(get_data, data) + + # Fully cached + assert len(handler) == len(fxt_data_list) + + @pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"]) + def test_unfully_caching(self, mode, fxt_data_list): + mem_size = str(get_data_list_size(fxt_data_list) // 2) + MemCacheHandler.create(mode, mem_size) + handler = MemCacheHandler() + + for idx, (key, data) in enumerate(fxt_data_list): + if idx < len(fxt_data_list) // 2: + assert handler.put(key, data) > 0 + else: + assert handler.put(key, data) is None + + for idx, (key, data) in enumerate(fxt_data_list): + get_data = handler.get(key) + + if idx < len(fxt_data_list) // 2: + assert np.array_equal(get_data, data) + else: + assert get_data is None + + # Unfully (half) cached + assert len(handler) == len(fxt_data_list) // 2 + + @pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"]) + @pytest.mark.parametrize( + "mem_size,expected", + [ + ("1561", 1561), + ("121k", 121 * (2**10)), + ("121kb", 121 * (2**10)), + ("121kib", 121 * (10**3)), + ("121as", None), + ("121dddd", None), + ], + ) + def test_mem_size_parsing(self, mode, mem_size, expected): + try: + MemCacheHandler.create(mode, mem_size) + handler = MemCacheHandler() + assert handler.mem_size == expected + except ValueError: + assert expected is None + + +class TestLoadImageFromFileWithCache: + @pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"]) + def test_combine_with_dataloader(self, mode, fxt_caching_dataset_cls, fxt_data_list): + mem_size = str(get_data_list_size(fxt_data_list)) + MemCacheHandler.create(mode, mem_size) + + dataset = fxt_caching_dataset_cls() + + for _ in DataLoader(dataset): + continue + + # This initial round requires file_client.get() for all data samples. + assert dataset.file_get_count == len(dataset) + + dataset.reset_file_count() + + for _ in DataLoader(dataset): + continue + + # The second round requires no file_client.get(). 
+ assert dataset.file_get_count == 0 From d78f00c0b18a20341252d88970b8cf967796486a Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Tue, 14 Feb 2023 12:23:35 +0900 Subject: [PATCH 02/15] Attach MemCacheHandler to otx.train - Did some refactoring - Add docstrings Signed-off-by: Kim, Vinnam --- .../adapters/mmcls/data/pipelines.py | 50 +---- .../classification/configs/configuration.yaml | 16 ++ .../common/configs/training_base.py | 10 + otx/algorithms/common/tasks/training_base.py | 29 ++- .../common/tools/caching/__init__.py | 9 + .../common/tools/caching/mem_cache_handler.py | 192 ++++++++++++++++++ .../common/tools/caching/mem_cache_hook.py | 29 +++ .../adapters/mmdet/data/pipelines.py | 50 +---- .../adapters/mmseg/data/pipelines.py | 48 +---- otx/cli/manager/config_manager.py | 13 +- otx/cli/tools/train.py | 11 + otx/cli/utils/parser.py | 44 ++++ otx/core/data/pipelines/__init__.py | 3 + .../pipelines/load_image_from_otx_dataset.py | 72 +++++++ .../datasets/pipelines/caching/__init__.py | 8 - .../load_image_from_file_with_cache.py | 82 -------- .../pipelines/caching/mem_cache_handler.py | 152 -------------- otx/mpa/modules/hooks/eval_hook.py | 1 + 18 files changed, 437 insertions(+), 382 deletions(-) create mode 100644 otx/algorithms/common/tools/caching/__init__.py create mode 100644 otx/algorithms/common/tools/caching/mem_cache_handler.py create mode 100644 otx/algorithms/common/tools/caching/mem_cache_hook.py create mode 100644 otx/core/data/pipelines/__init__.py create mode 100644 otx/core/data/pipelines/load_image_from_otx_dataset.py delete mode 100644 otx/mpa/modules/datasets/pipelines/caching/__init__.py delete mode 100644 otx/mpa/modules/datasets/pipelines/caching/load_image_from_file_with_cache.py delete mode 100644 otx/mpa/modules/datasets/pipelines/caching/mem_cache_handler.py diff --git a/otx/algorithms/classification/adapters/mmcls/data/pipelines.py b/otx/algorithms/classification/adapters/mmcls/data/pipelines.py index a13ede8a9bb..23543ddbc58 100644 --- a/otx/algorithms/classification/adapters/mmcls/data/pipelines.py +++ b/otx/algorithms/classification/adapters/mmcls/data/pipelines.py @@ -3,7 +3,6 @@ # Copyright (C) 2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import copy -import tempfile from typing import Any, Dict, List import numpy as np @@ -13,59 +12,16 @@ from PIL import Image, ImageFilter from torchvision import transforms as T -from otx.algorithms.common.utils.data import get_image +import otx.core.data.pipelines.load_image_from_otx_dataset as load_image_base from otx.api.utils.argument_checks import check_input_parameters_type -_CACHE_DIR = tempfile.TemporaryDirectory(prefix="img-cache-") # pylint: disable=consider-using-with - # TODO: refactoring to common modules # TODO: refactoring to Sphinx style. @PIPELINES.register_module() -class LoadImageFromOTXDataset: - """Pipeline element that loads an image from a OTX Dataset on the fly. - - Can do conversion to float 32 if needed. - Expected entries in the 'results' dict that should be passed to this pipeline element are: - results['dataset_item']: dataset_item from which to load the image - results['dataset_id']: id of the dataset to which the item belongs - results['index']: index of the item in the dataset - - :param to_float32: optional bool, True to convert images to fp32. 
defaults to False - """ - - @check_input_parameters_type() - def __init__(self, to_float32: bool = False): - self.to_float32 = to_float32 - - @check_input_parameters_type() - def __call__(self, results: Dict[str, Any]): - """Callback function of LoadImageFromOTXDataset.""" - # Get image (possibly from cache) - img = get_image(results, _CACHE_DIR.name, to_float32=self.to_float32) - shape = img.shape - - assert img.shape[0] == results["height"], f"{img.shape[0]} != {results['height']}" - assert img.shape[1] == results["width"], f"{img.shape[1]} != {results['width']}" - - filename = f"Dataset item index {results['index']}" - results["filename"] = filename - results["ori_filename"] = filename - results["img"] = img - results["img_shape"] = shape - results["ori_shape"] = shape - # Set initial values for default meta_keys - results["pad_shape"] = shape - num_channels = 1 if len(shape) < 3 else shape[2] - results["img_norm_cfg"] = dict( - mean=np.zeros(num_channels, dtype=np.float32), - std=np.ones(num_channels, dtype=np.float32), - to_rgb=False, - ) - results["img_fields"] = ["img"] - - return results +class LoadImageFromOTXDataset(load_image_base.LoadImageFromOTXDataset): + """Pipeline element that loads an image from a OTX Dataset on the fly.""" @PIPELINES.register_module() diff --git a/otx/algorithms/classification/configs/configuration.yaml b/otx/algorithms/classification/configs/configuration.yaml index 541a4ab3528..ef6585ab73e 100644 --- a/otx/algorithms/classification/configs/configuration.yaml +++ b/otx/algorithms/classification/configs/configuration.yaml @@ -354,5 +354,21 @@ algo_backend: value: INCREMENTAL visible_in_ui: True warning: null + mem_cache_size: + affects_outcome_of: TRAINING + default_value: 0 + description: Size of memory pool for caching decoded data to load data faster + editable: true + header: Size of memory pool for caching decoded data to load data faster + max_value: 9223372036854775807 + min_value: 0 + type: INTEGER + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + visible_in_ui: false + warning: null type: PARAMETER_GROUP visible_in_ui: true diff --git a/otx/algorithms/common/configs/training_base.py b/otx/algorithms/common/configs/training_base.py index c8e17af0392..1e99f5048ee 100644 --- a/otx/algorithms/common/configs/training_base.py +++ b/otx/algorithms/common/configs/training_base.py @@ -282,6 +282,16 @@ class BaseAlgoBackendParameters(ParameterGroup): visible_in_ui=True, ) + mem_cache_size = configurable_integer( + header="Size of memory pool for caching decoded data to load data faster", + description="Size of memory pool for caching decoded data to load data faster", + default_value=0, + min_value=0, + max_value=maxsize, + visible_in_ui=False, + affects_outcome_of=ModelLifecycle.TRAINING, + ) + @attrs class BaseTilingParameters(ParameterGroup): """BaseTilingParameters for OTX Algorithms.""" diff --git a/otx/algorithms/common/tasks/training_base.py b/otx/algorithms/common/tasks/training_base.py index 5ab1f6289da..a624eeefbe7 100644 --- a/otx/algorithms/common/tasks/training_base.py +++ b/otx/algorithms/common/tasks/training_base.py @@ -25,6 +25,7 @@ import numpy as np import torch +from mmcv.runner import get_dist_info from mmcv.utils.config import Config, ConfigDict from otx.algorithms.common.adapters.mmcv.hooks import OTXLoggerHook @@ -33,6 +34,7 @@ get_configs_by_pairs, ) from otx.algorithms.common.configs import TrainType +from otx.algorithms.common.tools import caching from otx.algorithms.common.utils import 
UncopiableDefaultDict from otx.api.entities.datasets import DatasetEntity from otx.api.entities.label import LabelEntity @@ -323,6 +325,9 @@ def _initialize(self, options=None): # noqa: C901 dataloader_cfg["persistent_workers"] = False data_cfg[f"{subset}_dataloader"] = dataloader_cfg + # Update recipe with caching modules + self._update_caching_modules(data_cfg) + if self._data_cfg is not None: align_data_config_with_recipe(self._data_cfg, self._recipe_cfg) @@ -403,7 +408,6 @@ def _init_deploy_cfg(self) -> Union[Config, None]: deploy_cfg = MPAConfig.fromfile(deploy_cfg_path) def patch_input_preprocessing(deploy_cfg): - normalize_cfg = get_configs_by_pairs( self._recipe_cfg.data.test.pipeline, dict(type="Normalize"), @@ -611,3 +615,26 @@ def set_early_stopping_hook(self): update_or_add_custom_hook(self._recipe_cfg, early_stop_hook) else: remove_custom_hook(self._recipe_cfg, "LazyEarlyStoppingHook") + + def _update_caching_modules(self, data_cfg: Config) -> None: + def _find_max_num_workers(cfg: dict): + num_workers = [0] + for key, value in cfg.items(): + if key == "workers_per_gpu" and isinstance(value, int): + num_workers += [value] + elif isinstance(value, dict): + num_workers += [_find_max_num_workers(value)] + + return max(num_workers) + + _, world_size = get_dist_info() + mem_cache_size = self.hyperparams.algo_backend.mem_cache_size // world_size + max_num_workers = _find_max_num_workers(data_cfg) + + mode = "multiprocessing" if max_num_workers > 0 else "singleprocessing" + caching.MemCacheHandlerSingleton.create(mode, mem_cache_size) + + update_or_add_custom_hook( + self._recipe_cfg, + ConfigDict(type="MemCacheHook"), + ) diff --git a/otx/algorithms/common/tools/caching/__init__.py b/otx/algorithms/common/tools/caching/__init__.py new file mode 100644 index 00000000000..64232048d44 --- /dev/null +++ b/otx/algorithms/common/tools/caching/__init__.py @@ -0,0 +1,9 @@ +"""Module for data caching.""" +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +from .mem_cache_handler import MemCacheHandlerSingleton +from .mem_cache_hook import MemCacheHook + +__all__ = ["MemCacheHandlerSingleton", "MemCacheHook"] diff --git a/otx/algorithms/common/tools/caching/mem_cache_handler.py b/otx/algorithms/common/tools/caching/mem_cache_handler.py new file mode 100644 index 00000000000..1cfd658df61 --- /dev/null +++ b/otx/algorithms/common/tools/caching/mem_cache_handler.py @@ -0,0 +1,192 @@ +"""Memory cache handler implementations and singleton class to call them.""" +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import ctypes as ct +import multiprocessing as mp +from multiprocessing.managers import DictProxy +from typing import Any, Dict, Optional, Union + +import numpy as np +from multiprocess.synchronize import Lock + +from otx.mpa.utils.logger import get_logger + +logger = get_logger() + + +class _DummyLock: + def __enter__(self, *args, **kwargs): + pass + + def __exit__(self, *args, **kwargs): + pass + + +class MemCacheHandlerBase: + """Base class for memory cache handler. + + It will be combined with LoadImageFromOTXDataset to store/retrieve the samples in memory. 
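+
+    A rough usage sketch (illustrative only; in practice the handler is
+    created via MemCacheHandlerSingleton rather than instantiated directly):
+
+        import numpy as np
+        handler = MemCacheHandlerBase(mem_size=16)
+        data = np.arange(8, dtype=np.uint8)
+        assert handler.put("key", data) is not None  # copied into the pool
+        cached = handler.get("key")  # zero-copy strided view over the pool
+        assert np.array_equal(cached, data)
+        assert handler.get("missing") is None  # cache miss returns None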
+ """ + + def __init__(self, mem_size: int): + self._init_data_structs(mem_size) + + def _init_data_structs(self, mem_size: int): + self._arr = (ct.c_uint8 * mem_size)() + self._cur_page = ct.c_size_t(0) + self._cache_addr: Union[Dict, DictProxy] = {} + self._lock: Union[Lock, _DummyLock] = _DummyLock() + self._freeze = ct.c_bool(False) + + def __len__(self): + """Get the number of cached items.""" + return len(self._cache_addr) + + @property + def mem_size(self) -> int: + """Get the reserved memory pool size (bytes).""" + return len(self._arr) + + def get(self, key: Any) -> Optional[np.ndarray]: + """Try to look up the cached item with the given key. + + Args: + key (Any): A key for looking up the cached item + + Returns: + If succeed return np.ndarray, otherwise return None + """ + if key not in self._cache_addr: + return None + + addr = self._cache_addr[key] + + offset, count, shape, strides = addr + + data = np.frombuffer(self._arr, dtype=np.uint8, count=count, offset=offset) + return np.lib.stride_tricks.as_strided(data, shape, strides) + + def put(self, key: Any, data: np.ndarray) -> Optional[int]: + """Try to store np.ndarray with a key to the reserved memory pool. + + Args: + key (Any): A key to store the cached item + data (np.ndarray): A data sample to store + + Returns: + Optional[int]: If succeed return the address of cached item in memory pool + """ + if self._freeze.value: + return None + + assert data.dtype == np.uint8 + + with self._lock: + new_page = self._cur_page.value + data.size + + if key in self._cache_addr or new_page > self.mem_size: + return None + + offset = ct.byref(self._arr, self._cur_page.value) + ct.memmove(offset, data.ctypes.data, data.size) + + self._cache_addr[key] = ( + self._cur_page.value, + data.size, + data.shape, + data.strides, + ) + self._cur_page.value = new_page + return new_page + + def __repr__(self): + """Representation for the current handler status.""" + perc = 100.0 * self._cur_page.value / self.mem_size + return ( + f"{self.__class__.__name__} " + f"uses {self._cur_page.value} / {self.mem_size} ({perc:.1f}%) memory pool and " + f"store {len(self)} items." + ) + + def freeze(self): + """If frozen, it is impossible to store a new item anymore.""" + self._freeze.value = True + + def unfreeze(self): + """If unfrozen, it is possible to store a new item.""" + self._freeze.value = False + + +class MemCacheHandlerForSP(MemCacheHandlerBase): + """Memory caching handler for single processing. + + Use if PyTorch's DataLoader.num_workers == 0. + """ + + +class MemCacheHandlerForMP(MemCacheHandlerBase): + """Memory caching handler for multi processing. + + Use if PyTorch's DataLoader.num_workers > 0. + """ + + def _init_data_structs(self, mem_size: int): + self._arr = mp.Array(ct.c_uint8, mem_size, lock=False) + self._cur_page = mp.Value(ct.c_size_t, 0, lock=False) + + self._manager = mp.Manager() + self._cache_addr: DictProxy = self._manager.dict() + self._lock = mp.Lock() + self._freeze = mp.Value(ct.c_bool, False, lock=False) + + def __del__(self): + """When deleting, manager should also be shutdowned.""" + self._manager.shutdown() + + +class MemCacheHandlerSingleton: + """A singleton class to create, delete and get MemCacheHandlerBase.""" + + instance: MemCacheHandlerBase + + @classmethod + def get(cls) -> MemCacheHandlerBase: + """Get the created MemCacheHandlerBase. + + If no one is created before, raise RuntimeError. 
+ """ + if not hasattr(cls, "instance"): + cls_name = cls.__class__.__name__ + raise RuntimeError(f"Before calling {cls_name}.get(), you should call {cls_name}.create() first.") + + return cls.instance + + @classmethod + def create(cls, mode: str, mem_size: int) -> MemCacheHandlerBase: + """Create a new MemCacheHandlerBase instance. + + Args: + mode (str): There are two options: multiprocessing or singleprocessing. + mem_size (int): The size of memory pool (bytes). + """ + logger.info(f"Try to create a {mem_size} size memory pool.") + + if mem_size == 0: + cls.instance = MemCacheHandlerBase(mem_size) + cls.instance.freeze() + elif mode == "multiprocessing": + cls.instance = MemCacheHandlerForMP(mem_size) + elif mode == "singleprocessing": + cls.instance = MemCacheHandlerForSP(mem_size) + else: + raise ValueError(f"{mode} is unknown mode.") + + return cls.instance + + @classmethod + def delete(cls) -> None: + """Delete the existing MemCacheHandlerBase instance.""" + if hasattr(cls, "instance"): + del cls.instance diff --git a/otx/algorithms/common/tools/caching/mem_cache_hook.py b/otx/algorithms/common/tools/caching/mem_cache_hook.py new file mode 100644 index 00000000000..1dc1462ae6d --- /dev/null +++ b/otx/algorithms/common/tools/caching/mem_cache_hook.py @@ -0,0 +1,29 @@ +"""Memory cache hook for logging and freezing MemCacheHandler.""" +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +from mmcv.runner.hooks import HOOKS, Hook + +from .mem_cache_handler import MemCacheHandlerSingleton + + +@HOOKS.register_module() +class MemCacheHook(Hook): + """Memory cache hook for logging and freezing MemCacheHandler.""" + + def __init__(self) -> None: + self.handler = MemCacheHandlerSingleton.get() + + def before_run(self, runner): + """Before run, freeze the handler.""" + self.handler.freeze() + + def before_epoch(self, runner): + """Before training, unfreeze the handler.""" + self.handler.unfreeze() + + def after_epoch(self, runner): + """After epoch. Log the handler statistics.""" + self.handler.freeze() + runner.logger.info(f"{self.handler}") diff --git a/otx/algorithms/detection/adapters/mmdet/data/pipelines.py b/otx/algorithms/detection/adapters/mmdet/data/pipelines.py index 836282067c1..0b2850c054c 100644 --- a/otx/algorithms/detection/adapters/mmdet/data/pipelines.py +++ b/otx/algorithms/detection/adapters/mmdet/data/pipelines.py @@ -13,65 +13,21 @@ # See the License for the specific language governing permissions # and limitations under the License. import copy -import tempfile from typing import Any, Dict, Optional -import numpy as np from mmdet.datasets.builder import PIPELINES -from otx.algorithms.common.utils.data import get_image +import otx.core.data.pipelines.load_image_from_otx_dataset as load_image_base from otx.api.entities.label import Domain from otx.api.utils.argument_checks import check_input_parameters_type from .dataset import get_annotation_mmdet_format -_CACHE_DIR = tempfile.TemporaryDirectory(prefix="img-cache-") # pylint: disable=consider-using-with - # pylint: disable=too-many-instance-attributes, too-many-arguments @PIPELINES.register_module() -class LoadImageFromOTXDataset: - """Pipeline element that loads an image from a OTX Dataset on the fly. Can do conversion to float 32 if needed. 
- - Expected entries in the 'results' dict that should be passed to this pipeline element are: - results['dataset_item']: dataset_item from which to load the image - results['dataset_id']: id of the dataset to which the item belongs - results['index']: index of the item in the dataset - - :param to_float32: optional bool, True to convert images to fp32. defaults to False - """ - - @check_input_parameters_type() - def __init__(self, to_float32: bool = False): - self.to_float32 = to_float32 - - @check_input_parameters_type() - def __call__(self, results: Dict[str, Any]): - """Callback function LoadImageFromOTXDataset.""" - # Get image (possibly from cache) - img = get_image(results, _CACHE_DIR.name, to_float32=self.to_float32) - shape = img.shape - - assert img.shape[0] == results["height"], f"{img.shape[0]} != {results['height']}" - assert img.shape[1] == results["width"], f"{img.shape[1]} != {results['width']}" - - filename = f"Dataset item index {results['index']}" - results["filename"] = filename - results["ori_filename"] = filename - results["img"] = img - results["img_shape"] = shape - results["ori_shape"] = shape - # Set initial values for default meta_keys - results["pad_shape"] = shape - num_channels = 1 if len(shape) < 3 else shape[2] - results["img_norm_cfg"] = dict( - mean=np.zeros(num_channels, dtype=np.float32), - std=np.ones(num_channels, dtype=np.float32), - to_rgb=False, - ) - results["img_fields"] = ["img"] - - return results +class LoadImageFromOTXDataset(load_image_base.LoadImageFromOTXDataset): + """Pipeline element that loads an image from a OTX Dataset on the fly.""" @PIPELINES.register_module() diff --git a/otx/algorithms/segmentation/adapters/mmseg/data/pipelines.py b/otx/algorithms/segmentation/adapters/mmseg/data/pipelines.py index e1cfdc4f2fa..9d0f0278954 100644 --- a/otx/algorithms/segmentation/adapters/mmseg/data/pipelines.py +++ b/otx/algorithms/segmentation/adapters/mmseg/data/pipelines.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions # and limitations under the License. -import tempfile from copy import deepcopy from typing import Any, Dict, List @@ -24,55 +23,16 @@ from torchvision import transforms as T from torchvision.transforms import functional as F -from otx.algorithms.common.utils.data import get_image +import otx.core.data.pipelines.load_image_from_otx_dataset as load_image_base from otx.api.utils.argument_checks import check_input_parameters_type from .dataset import get_annotation_mmseg_format -_CACHE_DIR = tempfile.TemporaryDirectory(prefix="img-cache-") # pylint: disable=consider-using-with - +# pylint: disable=too-many-instance-attributes, too-many-arguments @PIPELINES.register_module() -class LoadImageFromOTXDataset: - """Pipeline element that loads an image from a OTX Dataset on the fly. Can do conversion to float 32 if needed. - - Expected entries in the 'results' dict that should be passed to this pipeline element are: - results['dataset_item']: dataset_item from which to load the image - results['dataset_id']: id of the dataset to which the item belongs - results['index']: index of the item in the dataset - - :param to_float32: optional bool, True to convert images to fp32. 
defaults to False - """ - - @check_input_parameters_type() - def __init__(self, to_float32: bool = False): - self.to_float32 = to_float32 - - @check_input_parameters_type() - def __call__(self, results: Dict[str, Any]): - """Callback function LoadImageFromOTXDataset.""" - # Get image (possibly from cache) - img = get_image(results, _CACHE_DIR.name, to_float32=self.to_float32) - shape = img.shape - - assert img.shape[0] == results["height"], f"{img.shape[0]} != {results['height']}" - assert img.shape[1] == results["width"], f"{img.shape[1]} != {results['width']}" - - filename = f"Dataset item index {results['index']}" - results["filename"] = filename - results["ori_filename"] = filename - results["img"] = img - results["img_shape"] = shape - results["ori_shape"] = shape - # Set initial values for default meta_keys - results["pad_shape"] = shape - num_channels = 1 if len(shape) < 3 else shape[2] - results["img_norm_cfg"] = dict( - mean=np.zeros(num_channels, dtype=np.float32), std=np.ones(num_channels, dtype=np.float32), to_rgb=False - ) - results["img_fields"] = ["img"] - - return results +class LoadImageFromOTXDataset(load_image_base.LoadImageFromOTXDataset): + """Pipeline element that loads an image from a OTX Dataset on the fly.""" @PIPELINES.register_module() diff --git a/otx/cli/manager/config_manager.py b/otx/cli/manager/config_manager.py index d2623988e58..534c38bec44 100644 --- a/otx/cli/manager/config_manager.py +++ b/otx/cli/manager/config_manager.py @@ -262,7 +262,13 @@ def get_hyparams_config(self) -> ConfigurableParameters: type_hint = gen_param_help(hyper_parameters) updated_hyper_parameters = gen_params_dict_from_args(self.args, type_hint=type_hint) override_parameters(updated_hyper_parameters, hyper_parameters) - return create(hyper_parameters) + + # (vinnamki) I added this line because these lines above looks like + # it parses out of the arguments, but I'm wondering if this is working. + hyper_parameters = create(hyper_parameters) + hyper_parameters.algo_backend.mem_cache_size = self.args.mem_cache_size + + return hyper_parameters def get_dataset_config(self, subsets: List[str]) -> dict: """Returns dataset_config in a format suitable for each subset. @@ -398,6 +404,11 @@ def build_workspace(self, new_workspace_path: Optional[str] = None) -> None: print(f"[*] Load Model Template ID: {self.template.model_template_id}") print(f"[*] Load Model Name: {self.template.name}") + # if self.args.mem_cache_size is not None: + # self.template.mem_cache_size = self.args.mem_cache_size + # # TODO: need a logger + # print(f"[*] Override mem_cache_size to {self.args.mem_cache_size}") + def _copy_config_files(self, target_dir: Path, file_name: str, dest_dir: Path) -> None: """Copy Configuration files for workspace.""" if (target_dir / file_name).exists(): diff --git a/otx/cli/tools/train.py b/otx/cli/tools/train.py index c11753b20ba..ce7c61aae14 100644 --- a/otx/cli/tools/train.py +++ b/otx/cli/tools/train.py @@ -33,6 +33,7 @@ from otx.cli.utils.io import read_binary, read_label_schema, save_model_data from otx.cli.utils.multi_gpu import MultiGPUManager from otx.cli.utils.parser import ( + MemSizeAction, add_hyper_parameters_sub_parser, get_parser_and_hprams_data, ) @@ -111,6 +112,16 @@ def get_args(): default=0, help="Total number of workers in a worker group.", ) + parser.add_argument( + "--mem-cache-size", + action=MemSizeAction, + type=str, + required=False, + default=0, + help="Size of memory pool for caching decoded data to load data faster. " + "For example, you can use digits (e.g. 
1024) or a string with size units " + "(e.g. 7KB = 7 * 2^10, 3MB = 3 * 2^20, and 2GB = 2 * 2^30).", + ) sub_parser = add_hyper_parameters_sub_parser(parser, hyper_parameters, return_sub_parser=True) # TODO: Temporary solution for cases where there is no template input diff --git a/otx/cli/utils/parser.py b/otx/cli/utils/parser.py index 401e71cdb51..5028ca22933 100644 --- a/otx/cli/utils/parser.py +++ b/otx/cli/utils/parser.py @@ -15,12 +15,56 @@ # and limitations under the License. import argparse +import re from pathlib import Path from typing import Dict, Optional from otx.cli.registry import find_and_parse_model_template +class MemSizeAction(argparse.Action): + """Parser add on to parse memory size string.""" + + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super().__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + """Parse and set the attribute of namespace.""" + setattr(namespace, self.dest, self._parse_mem_size_str(values)) + + @staticmethod + def _parse_mem_size_str(mem_size: str) -> int: + assert isinstance(mem_size, str) + + match = re.match(r"^([\d\.]+)\s*([a-zA-Z]{0,3})$", mem_size.strip()) + + if match is None: + raise ValueError(f"Cannot parse {mem_size} string.") + + units = { + "": 1, + "B": 1, + "KB": 2**10, + "MB": 2**20, + "GB": 2**30, + "KIB": 10**3, + "MIB": 10**6, + "GIB": 10**9, + "K": 2**10, + "M": 2**20, + "G": 2**30, + } + + number, unit = int(match.group(1)), match.group(2).upper() + + if unit not in units: + raise ValueError(f"{mem_size} has disallowed unit ({unit}).") + + return number * units[unit] + + def gen_param_help(hyper_parameters): """Generates help for hyper parameters section.""" diff --git a/otx/core/data/pipelines/__init__.py b/otx/core/data/pipelines/__init__.py new file mode 100644 index 00000000000..699c2577892 --- /dev/null +++ b/otx/core/data/pipelines/__init__.py @@ -0,0 +1,3 @@ +"""OTX data pipelines.""" +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/otx/core/data/pipelines/load_image_from_otx_dataset.py b/otx/core/data/pipelines/load_image_from_otx_dataset.py new file mode 100644 index 00000000000..5f2ea954d47 --- /dev/null +++ b/otx/core/data/pipelines/load_image_from_otx_dataset.py @@ -0,0 +1,72 @@ +"""Pipeline element that loads an image from a OTX Dataset on the fly.""" +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from tempfile import TemporaryDirectory +from typing import Any, Dict + +import numpy as np + +from otx.algorithms.common.tools.caching import MemCacheHandlerSingleton +from otx.algorithms.common.utils.data import get_image +from otx.api.utils.argument_checks import check_input_parameters_type + +_CACHE_DIR = TemporaryDirectory(prefix="img-cache-") # pylint: disable=consider-using-with + +# TODO: refactoring to common modules +# TODO: refactoring to Sphinx style. + + +class LoadImageFromOTXDataset: + """Pipeline element that loads an image from a OTX Dataset on the fly. + + Can do conversion to float 32 if needed. + Expected entries in the 'results' dict that should be passed to this pipeline element are: + results['dataset_item']: dataset_item from which to load the image + results['dataset_id']: id of the dataset to which the item belongs + results['index']: index of the item in the dataset + + :param to_float32: optional bool, True to convert images to fp32. 
defaults to False + """ + + @check_input_parameters_type() + def __init__(self, to_float32: bool = False): + self.to_float32 = to_float32 + self.mem_cache_handler = MemCacheHandlerSingleton.get() + + @check_input_parameters_type() + def __call__(self, results: Dict[str, Any]): + """Callback function of LoadImageFromOTXDataset.""" + key = results["dataset_item"].media.path + + img = self.mem_cache_handler.get(key) + + if img is None: + # Get image (possibly from cache) + img = get_image(results, _CACHE_DIR.name, to_float32=False) + self.mem_cache_handler.put(key, img) + + if self.to_float32: + img = img.astype(np.float32) + shape = img.shape + + assert img.shape[0] == results["height"], f"{img.shape[0]} != {results['height']}" + assert img.shape[1] == results["width"], f"{img.shape[1]} != {results['width']}" + + filename = f"Dataset item index {results['index']}" + results["filename"] = filename + results["ori_filename"] = filename + results["img"] = img + results["img_shape"] = shape + results["ori_shape"] = shape + # Set initial values for default meta_keys + results["pad_shape"] = shape + num_channels = 1 if len(shape) < 3 else shape[2] + results["img_norm_cfg"] = dict( + mean=np.zeros(num_channels, dtype=np.float32), + std=np.ones(num_channels, dtype=np.float32), + to_rgb=False, + ) + results["img_fields"] = ["img"] + + return results diff --git a/otx/mpa/modules/datasets/pipelines/caching/__init__.py b/otx/mpa/modules/datasets/pipelines/caching/__init__.py deleted file mode 100644 index e0978a47373..00000000000 --- a/otx/mpa/modules/datasets/pipelines/caching/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -from .load_image_from_file_with_cache import LoadImageFromFileWithCache -from .mem_cache_handler import MemCacheHandler - -__all__ = ["MemCacheHandler", "LoadImageFromFileWithCache"] diff --git a/otx/mpa/modules/datasets/pipelines/caching/load_image_from_file_with_cache.py b/otx/mpa/modules/datasets/pipelines/caching/load_image_from_file_with_cache.py deleted file mode 100644 index 2d062b7945d..00000000000 --- a/otx/mpa/modules/datasets/pipelines/caching/load_image_from_file_with_cache.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -import os.path as osp - -import mmcv -import numpy as np -from mmdet.datasets.builder import PIPELINES -from mmdet.datasets.pipelines import LoadImageFromFile - -from .mem_cache_handler import MemCacheHandler - - -@PIPELINES.register_module() -class LoadImageFromFileWithCache(LoadImageFromFile): - """Load an image from file. - - Required keys are "img_prefix" and "img_info" (a dict that must contain the - key "filename"). Added or updated keys are "filename", "img", "img_shape", - "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), - "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). - - Args: - to_float32 (bool): Whether to convert the loaded image to a float32 - numpy array. If set to False, the loaded image is an uint8 array. - Defaults to False. - color_type (str): The flag argument for :func:`mmcv.imfrombytes`. - Defaults to 'color'. - file_client_args (dict): Arguments to instantiate a FileClient. - See :class:`mmcv.fileio.FileClient` for details. - Defaults to ``dict(backend='disk')``. 
- """ - - def __init__(self, to_float32=False, color_type="color", channel_order="bgr"): - self.to_float32 = to_float32 - self.color_type = color_type - self.channel_order = channel_order - self.file_client = mmcv.FileClient(backend="disk") - self.mem_cache_handler = MemCacheHandler() - - def __call__(self, results): - """Call functions to load image and get image meta information. - - Args: - results (dict): Result dict from :obj:`mmdet.CustomDataset`. - - Returns: - dict: The dict contains loaded image and meta information. - """ - - if results["img_prefix"] is not None: - filename = osp.join(results["img_prefix"], results["img_info"]["filename"]) - else: - filename = results["img_info"]["filename"] - - img = self.mem_cache_handler.get(key=filename) - - if img is None: - img_bytes = self.file_client.get(filename) - img = mmcv.imfrombytes(img_bytes, flag=self.color_type, channel_order=self.channel_order) - self.mem_cache_handler.put(key=filename, data=img) - - if self.to_float32: - img = img.astype(np.float32) - - results["filename"] = filename - results["ori_filename"] = results["img_info"]["filename"] - results["img"] = img - results["img_shape"] = img.shape - results["ori_shape"] = img.shape - results["img_fields"] = ["img"] - return results - - def __repr__(self): - return ( - f"{self.__class__.__name__}(" - f"to_float32={self.to_float32}, " - f"color_type='{self.color_type}', " - f"channel_order='{self.channel_order}', " - f"mem_cache_handler={self.mem_cache_handler})" - ) diff --git a/otx/mpa/modules/datasets/pipelines/caching/mem_cache_handler.py b/otx/mpa/modules/datasets/pipelines/caching/mem_cache_handler.py deleted file mode 100644 index 16cffb6d843..00000000000 --- a/otx/mpa/modules/datasets/pipelines/caching/mem_cache_handler.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -import ctypes as ct -import multiprocessing as mp -import re -from typing import Optional - -import numpy as np - -from otx.mpa.utils.logger import get_logger - -logger = get_logger() - - -class _DummyLock: - def __enter__(self, *args, **kwargs): - pass - - def __exit__(self, *args, **kwargs): - pass - - -class MemCacheHandlerForSP: - def __init__(self, mem_size: int): - self._init_data_structs(mem_size) - - def _init_data_structs(self, mem_size: int): - self.arr = (ct.c_uint8 * mem_size)() - self.cur_page = ct.c_size_t(0) - self.cache_addr = {} - self.lock = _DummyLock() - - def __len__(self): - return len(self.cache_addr) - - @property - def mem_size(self) -> int: - return len(self.arr) - - def get(self, key: str) -> Optional[np.ndarray]: - if key not in self.cache_addr: - return None - - addr = self.cache_addr[key] - - offset, count, shape, strides = addr - - data = np.frombuffer(self.arr, dtype=np.uint8, count=count, offset=offset) - return np.lib.stride_tricks.as_strided(data, shape, strides) - - def put(self, key: str, data: np.ndarray) -> Optional[int]: - assert data.dtype == np.uint8 - - with self.lock: - new_page = self.cur_page.value + data.size - - if key in self.cache_addr or new_page > self.mem_size: - return None - - offset = ct.byref(self.arr, self.cur_page.value) - ct.memmove(offset, data.ctypes.data, data.size) - - self.cache_addr[key] = ( - self.cur_page.value, - data.size, - data.shape, - data.strides, - ) - self.cur_page.value = new_page - return new_page - - def __repr__(self): - return ( - f"{self.__class__.__name__} " - f"uses {self.cur_page.value} / {self.mem_size} memory pool and " - f"store {len(self)} items." 
- ) - - -class MemCacheHandlerForMP(MemCacheHandlerForSP): - def __init__(self, mem_size: int): - super().__init__(mem_size) - - def _init_data_structs(self, mem_size: int): - self.arr = mp.Array(ct.c_uint8, mem_size, lock=False) - self.cur_page = mp.Value(ct.c_size_t, 0, lock=False) - - self.manager = mp.Manager() - self.cache_addr = self.manager.dict() - self.lock = mp.Lock() - - def __del__(self): - self.manager.shutdown() - - -class MemCacheHandler(MemCacheHandlerForSP): - instance = Optional[MemCacheHandlerForSP] - - def __init__(self): - pass - - def __new__(cls) -> Optional[MemCacheHandlerForSP]: - if not hasattr(cls, "instance"): - raise RuntimeError(f"Before calling {cls.__name__}(), you should call {cls.__name__}.create() first.") - - return cls.instance - - @classmethod - def create(cls, mode: str, mem_size: str) -> Optional[MemCacheHandlerForSP]: - mem_size = cls._parse_mem_size_str(mem_size) - logger.info(f"Try to create a {mem_size} size memory pool.") - - if mode == "multiprocessing": - cls.instance = MemCacheHandlerForMP(mem_size) - elif mode == "singleprocessing": - cls.instance = MemCacheHandlerForSP(mem_size) - else: - raise ValueError(f"{mode} is unknown mode.") - - return cls.instance - - @staticmethod - def _parse_mem_size_str(mem_size: str) -> int: - assert isinstance(mem_size, str) - - m = re.match(r"^([\d\.]+)\s*([a-zA-Z]{0,3})$", mem_size.strip()) - - if m is None: - raise ValueError(f"Cannot parse {mem_size} string.") - - units = { - "": 1, - "B": 1, - "KB": 2**10, - "MB": 2**20, - "GB": 2**30, - "KIB": 10**3, - "MIB": 10**6, - "GIB": 10**9, - "K": 2**10, - "M": 2**20, - "G": 2**30, - } - - number, unit = int(m.group(1)), m.group(2).upper() - - if unit not in units: - raise ValueError(f"{mem_size} has disallowed unit ({unit}).") - - return number * units[unit] diff --git a/otx/mpa/modules/hooks/eval_hook.py b/otx/mpa/modules/hooks/eval_hook.py index 920c5839a9b..eed4e743996 100644 --- a/otx/mpa/modules/hooks/eval_hook.py +++ b/otx/mpa/modules/hooks/eval_hook.py @@ -96,6 +96,7 @@ def single_gpu_test(model, data_loader): batch_size = data["img"].size(0) for _ in range(batch_size): prog_bar.update() + prog_bar.file.write("\n") return results From bc367ab95a01cc2217d214eac9352711132bc799 Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Tue, 14 Feb 2023 16:46:24 +0900 Subject: [PATCH 03/15] Refactor and fix unit tests Signed-off-by: Kim, Vinnam --- otx/algorithms/common/tasks/training_base.py | 2 +- otx/api/entities/dataset_item.py | 8 +- .../tools => core/data}/caching/__init__.py | 0 .../data}/caching/mem_cache_handler.py | 2 +- .../data}/caching/mem_cache_hook.py | 0 .../pipelines/load_image_from_otx_dataset.py | 15 +- tests/unit/cli/utils/test_parser.py | 47 +++++ tests/unit/core/__init__.py | 3 + tests/unit/core/data/__init__.py | 3 + tests/unit/core/data/test_caching.py | 145 +++++++++++++++ tests/unit/mpa/test_caching.py | 165 ------------------ 11 files changed, 215 insertions(+), 175 deletions(-) rename otx/{algorithms/common/tools => core/data}/caching/__init__.py (100%) rename otx/{algorithms/common/tools => core/data}/caching/mem_cache_handler.py (98%) rename otx/{algorithms/common/tools => core/data}/caching/mem_cache_hook.py (100%) create mode 100644 tests/unit/cli/utils/test_parser.py create mode 100644 tests/unit/core/__init__.py create mode 100644 tests/unit/core/data/__init__.py create mode 100644 tests/unit/core/data/test_caching.py delete mode 100644 tests/unit/mpa/test_caching.py diff --git a/otx/algorithms/common/tasks/training_base.py 
b/otx/algorithms/common/tasks/training_base.py index a624eeefbe7..f1b8b702326 100644 --- a/otx/algorithms/common/tasks/training_base.py +++ b/otx/algorithms/common/tasks/training_base.py @@ -34,7 +34,6 @@ get_configs_by_pairs, ) from otx.algorithms.common.configs import TrainType -from otx.algorithms.common.tools import caching from otx.algorithms.common.utils import UncopiableDefaultDict from otx.api.entities.datasets import DatasetEntity from otx.api.entities.label import LabelEntity @@ -47,6 +46,7 @@ from otx.api.usecases.tasks.interfaces.inference_interface import IInferenceTask from otx.api.usecases.tasks.interfaces.unload_interface import IUnload from otx.api.utils.argument_checks import check_input_parameters_type +from otx.core.data import caching from otx.mpa.builder import build from otx.mpa.modules.hooks.cancel_interface_hook import CancelInterfaceHook from otx.mpa.stage import Stage diff --git a/otx/api/entities/dataset_item.py b/otx/api/entities/dataset_item.py index 762db27f2b4..ba542ebcbc6 100644 --- a/otx/api/entities/dataset_item.py +++ b/otx/api/entities/dataset_item.py @@ -110,6 +110,8 @@ def __init__( if Rectangle.is_full_box(annotation.shape): roi = annotation break + if roi is None: + roi = Annotation(Rectangle.generate_full_box(), labels=[]) self.__roi = roi self.__metadata: List[MetadataItemEntity] = [] @@ -150,11 +152,7 @@ def __repr__(self): def roi(self) -> Annotation: """Region Of Interest.""" with self.__roi_lock: - if self.__roi is None: - requested_roi = Annotation(Rectangle.generate_full_box(), labels=[]) - self.__roi = requested_roi - else: - requested_roi = self.__roi + requested_roi = self.__roi return requested_roi @roi.setter diff --git a/otx/algorithms/common/tools/caching/__init__.py b/otx/core/data/caching/__init__.py similarity index 100% rename from otx/algorithms/common/tools/caching/__init__.py rename to otx/core/data/caching/__init__.py diff --git a/otx/algorithms/common/tools/caching/mem_cache_handler.py b/otx/core/data/caching/mem_cache_handler.py similarity index 98% rename from otx/algorithms/common/tools/caching/mem_cache_handler.py rename to otx/core/data/caching/mem_cache_handler.py index 1cfd658df61..3f1eda2031f 100644 --- a/otx/algorithms/common/tools/caching/mem_cache_handler.py +++ b/otx/core/data/caching/mem_cache_handler.py @@ -103,7 +103,7 @@ def put(self, key: Any, data: np.ndarray) -> Optional[int]: def __repr__(self): """Representation for the current handler status.""" - perc = 100.0 * self._cur_page.value / self.mem_size + perc = 100.0 * self._cur_page.value / self.mem_size if self.mem_size > 0 else 0.0 return ( f"{self.__class__.__name__} " f"uses {self._cur_page.value} / {self.mem_size} ({perc:.1f}%) memory pool and " diff --git a/otx/algorithms/common/tools/caching/mem_cache_hook.py b/otx/core/data/caching/mem_cache_hook.py similarity index 100% rename from otx/algorithms/common/tools/caching/mem_cache_hook.py rename to otx/core/data/caching/mem_cache_hook.py diff --git a/otx/core/data/pipelines/load_image_from_otx_dataset.py b/otx/core/data/pipelines/load_image_from_otx_dataset.py index 5f2ea954d47..281554d9eb2 100644 --- a/otx/core/data/pipelines/load_image_from_otx_dataset.py +++ b/otx/core/data/pipelines/load_image_from_otx_dataset.py @@ -3,14 +3,15 @@ # SPDX-License-Identifier: Apache-2.0 from tempfile import TemporaryDirectory -from typing import Any, Dict +from typing import Any, Dict, Tuple import numpy as np -from otx.algorithms.common.tools.caching import MemCacheHandlerSingleton from 
otx.algorithms.common.utils.data import get_image from otx.api.utils.argument_checks import check_input_parameters_type +from ..caching import MemCacheHandlerSingleton + _CACHE_DIR = TemporaryDirectory(prefix="img-cache-") # pylint: disable=consider-using-with # TODO: refactoring to common modules @@ -34,10 +35,18 @@ def __init__(self, to_float32: bool = False): self.to_float32 = to_float32 self.mem_cache_handler = MemCacheHandlerSingleton.get() + @staticmethod + def _get_unique_key(results: Dict[str, Any]) -> Tuple: + # TODO: We should improve it by assigning an unique id to DatasetItemEntity. + # This is because there is a case which + # d_item.media.path is None, but d_item.media.data is not None + d_item = results["dataset_item"] + return d_item.media.path, d_item.roi.id + @check_input_parameters_type() def __call__(self, results: Dict[str, Any]): """Callback function of LoadImageFromOTXDataset.""" - key = results["dataset_item"].media.path + key = self._get_unique_key(results) img = self.mem_cache_handler.get(key) diff --git a/tests/unit/cli/utils/test_parser.py b/tests/unit/cli/utils/test_parser.py new file mode 100644 index 00000000000..cecf6d091d6 --- /dev/null +++ b/tests/unit/cli/utils/test_parser.py @@ -0,0 +1,47 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import argparse + +import pytest + +from otx.cli.utils.parser import MemSizeAction + + +@pytest.fixture +def fxt_argparse(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--mem-cache-size", + action=MemSizeAction, + type=str, + required=False, + default=0, + ) + return parser + + +@pytest.mark.parametrize( + "mem_size_arg,expected", + [ + ("1561", 1561), + ("121k", 121 * (2**10)), + ("121kb", 121 * (2**10)), + ("121kib", 121 * (10**3)), + ("121m", 121 * (2**20)), + ("121mb", 121 * (2**20)), + ("121mib", 121 * (10**6)), + ("121g", 121 * (2**30)), + ("121gb", 121 * (2**30)), + ("121gib", 121 * (10**9)), + ("121as", None), + ("121dddd", None), + ], +) +def test_mem_size_parsing(fxt_argparse, mem_size_arg, expected): + try: + args = fxt_argparse.parse_args(["--mem-cache-size", mem_size_arg]) + assert args.mem_cache_size == expected + except ValueError: + assert expected is None diff --git a/tests/unit/core/__init__.py b/tests/unit/core/__init__.py new file mode 100644 index 00000000000..9c68be83ef0 --- /dev/null +++ b/tests/unit/core/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# diff --git a/tests/unit/core/data/__init__.py b/tests/unit/core/data/__init__.py new file mode 100644 index 00000000000..9c68be83ef0 --- /dev/null +++ b/tests/unit/core/data/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# diff --git a/tests/unit/core/data/test_caching.py b/tests/unit/core/data/test_caching.py new file mode 100644 index 00000000000..3a49969c25a --- /dev/null +++ b/tests/unit/core/data/test_caching.py @@ -0,0 +1,145 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import string +from unittest.mock import patch + +import numpy as np +import pytest +from torch.utils.data import DataLoader, Dataset + +from otx.api.entities.annotation import AnnotationSceneEntity, AnnotationSceneKind +from otx.api.entities.dataset_item import DatasetItemEntity +from otx.api.entities.image import Image +from otx.core.data.caching import MemCacheHandlerSingleton +from otx.core.data.pipelines.load_image_from_otx_dataset import 
LoadImageFromOTXDataset + + +@pytest.fixture +def fxt_data_list(): + np.random.seed(3003) + + num_data = 10 + h = w = key_len = 16 + + data_list = [] + for _ in range(num_data): + data = np.random.randint(0, 256, size=[h, w, 3], dtype=np.uint8) + key = "".join( + [string.ascii_lowercase[i] for i in np.random.randint(0, len(string.ascii_lowercase), size=[key_len])] + ) + data_list += [(key, data)] + + return data_list + + +@pytest.fixture +def fxt_caching_dataset_cls(fxt_data_list: list): + class CachingDataset(Dataset): + def __init__(self) -> None: + super().__init__() + self.d_items = [ + DatasetItemEntity( + media=Image(data=data), + annotation_scene=AnnotationSceneEntity(annotations=[], kind=AnnotationSceneKind.ANNOTATION), + ) + for _, data in fxt_data_list + ] + self.load = LoadImageFromOTXDataset() + + def __len__(self): + return len(self.d_items) + + def __getitem__(self, index): + d_item = self.d_items[index] + + results = { + "dataset_item": d_item, + "height": d_item.media.numpy.shape[0], + "width": d_item.media.numpy.shape[1], + "index": index, + } + + results = self.load(results) + return results["img"] + + yield CachingDataset + + +def get_data_list_size(data_list): + size = 0 + for _, data in data_list: + size += data.size + return size + + +class TestMemCacheHandler: + @pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"]) + def test_fully_caching(self, mode, fxt_data_list): + mem_size = get_data_list_size(fxt_data_list) + MemCacheHandlerSingleton.create(mode, mem_size) + handler = MemCacheHandlerSingleton.get() + + for key, data in fxt_data_list: + assert handler.put(key, data) > 0 + + for key, data in fxt_data_list: + get_data = handler.get(key) + + assert np.array_equal(get_data, data) + + # Fully cached + assert len(handler) == len(fxt_data_list) + + @pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"]) + def test_unfully_caching(self, mode, fxt_data_list): + mem_size = get_data_list_size(fxt_data_list) // 2 + MemCacheHandlerSingleton.create(mode, mem_size) + handler = MemCacheHandlerSingleton.get() + + for idx, (key, data) in enumerate(fxt_data_list): + if idx < len(fxt_data_list) // 2: + assert handler.put(key, data) > 0 + else: + assert handler.put(key, data) is None + + for idx, (key, data) in enumerate(fxt_data_list): + get_data = handler.get(key) + + if idx < len(fxt_data_list) // 2: + assert np.array_equal(get_data, data) + else: + assert get_data is None + + # Unfully (half) cached + assert len(handler) == len(fxt_data_list) // 2 + + +class TestLoadImageFromFileWithCache: + @pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"]) + def test_combine_with_dataloader(self, mode, fxt_caching_dataset_cls, fxt_data_list): + mem_size = get_data_list_size(fxt_data_list) + MemCacheHandlerSingleton.create(mode, mem_size) + + dataset = fxt_caching_dataset_cls() + + with patch( + "otx.core.data.pipelines.load_image_from_otx_dataset.get_image", + side_effect=[data for _, data in fxt_data_list], + ) as mock: + for _ in DataLoader(dataset): + continue + + # This initial round requires all data samples to be put(). + assert mock.call_count == len(dataset) + + with patch( + "otx.core.data.pipelines.load_image_from_otx_dataset.get_image", + side_effect=[data for _, data in fxt_data_list], + ) as mock: + for _ in DataLoader(dataset): + continue + + # The second round requires no put(). 
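+            # Every sample is now served from the in-memory cache populated during
+            # the first pass, so the patched get_image() is never invoked again.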
+ assert mock.call_count == 0 diff --git a/tests/unit/mpa/test_caching.py b/tests/unit/mpa/test_caching.py deleted file mode 100644 index ab8e202dde9..00000000000 --- a/tests/unit/mpa/test_caching.py +++ /dev/null @@ -1,165 +0,0 @@ -# Copyright (C) 2023 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -import os.path as osp -import string -from tempfile import TemporaryDirectory - -import cv2 -import numpy as np -import pytest -from torch.utils.data import DataLoader, Dataset - -from otx.mpa.modules.datasets.pipelines.caching import ( - LoadImageFromFileWithCache, - MemCacheHandler, -) - - -@pytest.fixture -def fxt_data_list(): - np.random.seed(3003) - - num_data = 10 - h = w = key_len = 16 - - data_list = [] - for _ in range(num_data): - data = np.random.randint(0, 256, size=[h, w, 3], dtype=np.uint8) - key = "".join( - [string.ascii_lowercase[i] for i in np.random.randint(0, len(string.ascii_lowercase), size=[key_len])] - ) - data_list += [(key, data)] - - return data_list - - -@pytest.fixture -def fxt_caching_dataset_cls(fxt_data_list): - with TemporaryDirectory() as img_prefix: - for key, data in fxt_data_list: - cv2.imwrite(osp.join(img_prefix, key + ".png"), data) - - class CachingDataset(Dataset): - def __init__(self) -> None: - super().__init__() - self.data_list = fxt_data_list - self.load = LoadImageFromFileWithCache() - self.file_get_count = 0 - - __get = self.load.file_client.get - - def _get(filepath): - self.file_get_count += 1 - return __get(filepath) - - self.load.file_client.get = _get - - def reset_file_count(self): - self.file_get_count = 0 - - def __len__(self): - return len(self.data_list) - - def __getitem__(self, index): - key, _ = self.data_list[index] - results = { - "img_prefix": img_prefix, - "img_info": {"filename": key + ".png"}, - } - return self.load(results) - - yield CachingDataset - - -def get_data_list_size(data_list): - size = 0 - for _, data in data_list: - size += data.size - return size - - -class TestMemCacheHandler: - @pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"]) - def test_fully_caching(self, mode, fxt_data_list): - mem_size = str(get_data_list_size(fxt_data_list)) - MemCacheHandler.create(mode, mem_size) - handler = MemCacheHandler() - - for key, data in fxt_data_list: - assert handler.put(key, data) > 0 - - for key, data in fxt_data_list: - get_data = handler.get(key) - - assert np.array_equal(get_data, data) - - # Fully cached - assert len(handler) == len(fxt_data_list) - - @pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"]) - def test_unfully_caching(self, mode, fxt_data_list): - mem_size = str(get_data_list_size(fxt_data_list) // 2) - MemCacheHandler.create(mode, mem_size) - handler = MemCacheHandler() - - for idx, (key, data) in enumerate(fxt_data_list): - if idx < len(fxt_data_list) // 2: - assert handler.put(key, data) > 0 - else: - assert handler.put(key, data) is None - - for idx, (key, data) in enumerate(fxt_data_list): - get_data = handler.get(key) - - if idx < len(fxt_data_list) // 2: - assert np.array_equal(get_data, data) - else: - assert get_data is None - - # Unfully (half) cached - assert len(handler) == len(fxt_data_list) // 2 - - @pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"]) - @pytest.mark.parametrize( - "mem_size,expected", - [ - ("1561", 1561), - ("121k", 121 * (2**10)), - ("121kb", 121 * (2**10)), - ("121kib", 121 * (10**3)), - ("121as", None), - ("121dddd", None), - ], - ) - def test_mem_size_parsing(self, mode, mem_size, expected): - 
try: - MemCacheHandler.create(mode, mem_size) - handler = MemCacheHandler() - assert handler.mem_size == expected - except ValueError: - assert expected is None - - -class TestLoadImageFromFileWithCache: - @pytest.mark.parametrize("mode", ["singleprocessing", "multiprocessing"]) - def test_combine_with_dataloader(self, mode, fxt_caching_dataset_cls, fxt_data_list): - mem_size = str(get_data_list_size(fxt_data_list)) - MemCacheHandler.create(mode, mem_size) - - dataset = fxt_caching_dataset_cls() - - for _ in DataLoader(dataset): - continue - - # This initial round requires file_client.get() for all data samples. - assert dataset.file_get_count == len(dataset) - - dataset.reset_file_count() - - for _ in DataLoader(dataset): - continue - - # The second round requires no file_client.get(). - assert dataset.file_get_count == 0 From 06d3eaae4c2b9edf9c4c06cfb464d1fa9a58838e Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Tue, 14 Feb 2023 17:20:57 +0900 Subject: [PATCH 04/15] Update configuration.yaml for other tasks Signed-off-by: Kim, Vinnam --- .../classification/configs/configuration.yaml | 4 ++-- .../configs/detection/configuration.yaml | 16 ++++++++++++++++ .../instance_segmentation/configuration.yaml | 16 ++++++++++++++++ .../segmentation/configs/configuration.yaml | 16 ++++++++++++++++ otx/cli/manager/config_manager.py | 5 ----- otx/cli/tools/train.py | 2 +- otx/core/data/caching/mem_cache_handler.py | 2 +- 7 files changed, 52 insertions(+), 9 deletions(-) diff --git a/otx/algorithms/classification/configs/configuration.yaml b/otx/algorithms/classification/configs/configuration.yaml index ef6585ab73e..897c3f7e13f 100644 --- a/otx/algorithms/classification/configs/configuration.yaml +++ b/otx/algorithms/classification/configs/configuration.yaml @@ -357,9 +357,9 @@ algo_backend: mem_cache_size: affects_outcome_of: TRAINING default_value: 0 - description: Size of memory pool for caching decoded data to load data faster + description: Size of memory pool for caching decoded data to load data faster (bytes). editable: true - header: Size of memory pool for caching decoded data to load data faster + header: Size of memory pool max_value: 9223372036854775807 min_value: 0 type: INTEGER diff --git a/otx/algorithms/detection/configs/detection/configuration.yaml b/otx/algorithms/detection/configs/detection/configuration.yaml index 749ce505b4e..cd8ba0eadff 100644 --- a/otx/algorithms/detection/configs/detection/configuration.yaml +++ b/otx/algorithms/detection/configs/detection/configuration.yaml @@ -262,6 +262,22 @@ algo_backend: value: INCREMENTAL visible_in_ui: True warning: null + mem_cache_size: + affects_outcome_of: TRAINING + default_value: 0 + description: Size of memory pool for caching decoded data to load data faster (bytes). 
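+      # A default_value of 0 disables in-memory caching.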
+ editable: true + header: Size of memory pool + max_value: 9223372036854775807 + min_value: 0 + type: INTEGER + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + visible_in_ui: false + warning: null type: PARAMETER_GROUP visible_in_ui: true type: CONFIGURABLE_PARAMETERS diff --git a/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml b/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml index 51d7e9d3696..57693128302 100644 --- a/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml +++ b/otx/algorithms/detection/configs/instance_segmentation/configuration.yaml @@ -262,6 +262,22 @@ algo_backend: value: INCREMENTAL visible_in_ui: True warning: null + mem_cache_size: + affects_outcome_of: TRAINING + default_value: 0 + description: Size of memory pool for caching decoded data to load data faster (bytes). + editable: true + header: Size of memory pool + max_value: 9223372036854775807 + min_value: 0 + type: INTEGER + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + visible_in_ui: false + warning: null type: PARAMETER_GROUP visible_in_ui: true type: CONFIGURABLE_PARAMETERS diff --git a/otx/algorithms/segmentation/configs/configuration.yaml b/otx/algorithms/segmentation/configs/configuration.yaml index 689a77855c5..0da91d335ba 100644 --- a/otx/algorithms/segmentation/configs/configuration.yaml +++ b/otx/algorithms/segmentation/configs/configuration.yaml @@ -292,6 +292,22 @@ algo_backend: value: INCREMENTAL visible_in_ui: True warning: null + mem_cache_size: + affects_outcome_of: TRAINING + default_value: 0 + description: Size of memory pool for caching decoded data to load data faster (bytes). + editable: true + header: Size of memory pool + max_value: 9223372036854775807 + min_value: 0 + type: INTEGER + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + visible_in_ui: false + warning: null type: PARAMETER_GROUP visible_in_ui: true type: CONFIGURABLE_PARAMETERS diff --git a/otx/cli/manager/config_manager.py b/otx/cli/manager/config_manager.py index 534c38bec44..24663753ed2 100644 --- a/otx/cli/manager/config_manager.py +++ b/otx/cli/manager/config_manager.py @@ -404,11 +404,6 @@ def build_workspace(self, new_workspace_path: Optional[str] = None) -> None: print(f"[*] Load Model Template ID: {self.template.model_template_id}") print(f"[*] Load Model Name: {self.template.name}") - # if self.args.mem_cache_size is not None: - # self.template.mem_cache_size = self.args.mem_cache_size - # # TODO: need a logger - # print(f"[*] Override mem_cache_size to {self.args.mem_cache_size}") - def _copy_config_files(self, target_dir: Path, file_name: str, dest_dir: Path) -> None: """Copy Configuration files for workspace.""" if (target_dir / file_name).exists(): diff --git a/otx/cli/tools/train.py b/otx/cli/tools/train.py index ce7c61aae14..4e5e6ecaa7c 100644 --- a/otx/cli/tools/train.py +++ b/otx/cli/tools/train.py @@ -119,7 +119,7 @@ def get_args(): required=False, default=0, help="Size of memory pool for caching decoded data to load data faster. " - "For example, you can use digits (e.g. 1024) or a string with size units " + "For example, you can use digits for bytes size (e.g. 1024) or a string with size units " "(e.g. 
7KB = 7 * 2^10, 3MB = 3 * 2^20, and 2GB = 2 * 2^30).",
     )
 
diff --git a/otx/core/data/caching/mem_cache_handler.py b/otx/core/data/caching/mem_cache_handler.py
index 3f1eda2031f..351658fa192 100644
--- a/otx/core/data/caching/mem_cache_handler.py
+++ b/otx/core/data/caching/mem_cache_handler.py
@@ -58,7 +58,7 @@ def get(self, key: Any) -> Optional[np.ndarray]:
         Returns:
             If succeed return np.ndarray, otherwise return None
         """
-        if key not in self._cache_addr:
+        if self.mem_size == 0 or key not in self._cache_addr:
             return None
 
         addr = self._cache_addr[key]

From 93cb94086df8ca3ed8f875eef5e06e745d53d838 Mon Sep 17 00:00:00 2001
From: "Kim, Vinnam"
Date: Tue, 14 Feb 2023 17:31:25 +0900
Subject: [PATCH 05/15] Update QUICK_START_GUIDE.md

Signed-off-by: Kim, Vinnam
---
 QUICK_START_GUIDE.md | 60 ++++++++++++++++++++++++++++----------------
 1 file changed, 38 insertions(+), 22 deletions(-)

diff --git a/QUICK_START_GUIDE.md b/QUICK_START_GUIDE.md
index 46c4c0ad900..a3462d04893 100644
--- a/QUICK_START_GUIDE.md
+++ b/QUICK_START_GUIDE.md
@@ -104,8 +104,10 @@ And with the `--help` command along with `template`, you can list additional inf
 ```bash
 # Command example to get common parameters for any model template
 (otx) ...$ otx train otx/algorithms/detection/configs/detection/mobilenetv2_ssd/template.yaml --help
-usage: otx train [-h] --train-ann-files TRAIN_ANN_FILES --train-data-roots TRAIN_DATA_ROOTS --val-ann-files VAL_ANN_FILES --val-data-roots VAL_DATA_ROOTS [--load-weights LOAD_WEIGHTS] --save-model-to SAVE_MODEL_TO
-                 [--enable-hpo] [--hpo-time-ratio HPO_TIME_RATIO]
+usage: otx train [-h] [--train-data-roots TRAIN_DATA_ROOTS] [--val-data-roots VAL_DATA_ROOTS] [--unlabeled-data-roots UNLABELED_DATA_ROOTS]
+                 [--unlabeled-file-list UNLABELED_FILE_LIST] [--load-weights LOAD_WEIGHTS] [--resume-from RESUME_FROM] [--save-model-to SAVE_MODEL_TO] [--work-dir WORK_DIR]
+                 [--enable-hpo] [--hpo-time-ratio HPO_TIME_RATIO] [--gpus GPUS] [--rdzv-endpoint RDZV_ENDPOINT] [--base-rank BASE_RANK] [--world-size WORLD_SIZE]
+                 [--mem-cache-size MEM_CACHE_SIZE]
                  template {params} ...

 positional arguments:
@@ -115,23 +117,34 @@ positional arguments:

 optional arguments:
   -h, --help            show this help message and exit
-  --train-ann-files TRAIN_ANN_FILES
-                        Comma-separated paths to training annotation files.
   --train-data-roots TRAIN_DATA_ROOTS
                         Comma-separated paths to training data folders.
-  --val-ann-files VAL_ANN_FILES
-                        Comma-separated paths to validation annotation files.
   --val-data-roots VAL_DATA_ROOTS
                         Comma-separated paths to validation data folders.
+  --unlabeled-data-roots UNLABELED_DATA_ROOTS
+                        Comma-separated paths to unlabeled data folders
+  --unlabeled-file-list UNLABELED_FILE_LIST
+                        Comma-separated paths to unlabeled file list
   --load-weights LOAD_WEIGHTS
-                        Load only weights from previously saved checkpoint
+                        Load model weights from previously saved checkpoint.
   --resume-from RESUME_FROM
-                        Resume training from previously saved checkpoint
+                        Resume training from previously saved checkpoint
   --save-model-to SAVE_MODEL_TO
                         Location where trained model will be stored.
+  --work-dir WORK_DIR   Location where the intermediate output of the training will be stored.
   --enable-hpo          Execute hyper parameters optimization (HPO) before training.
   --hpo-time-ratio HPO_TIME_RATIO
                         Expected ratio of total time to run HPO to time taken for full fine-tuning.
+  --gpus GPUS           Comma-separated indices of GPU. If more than one GPU is available, the model is trained with multiple GPUs.
+  --rdzv-endpoint RDZV_ENDPOINT
+                        Rendezvous endpoint for multi-node training.
+  --base-rank BASE_RANK
+                        Base rank of the current node workers.
+  --world-size WORLD_SIZE
+                        Total number of workers in a worker group.
+  --mem-cache-size MEM_CACHE_SIZE
+                        Size of memory pool for caching decoded data to load data faster. For example, you can use digits for bytes size (e.g. 1024) or a string with size
+                        units (e.g. 7KB = 7 * 2^10, 3MB = 3 * 2^20, and 2GB = 2 * 2^30).
 ```

 #### Model template-specific parameters
@@ -176,26 +189,29 @@ optional arguments:

 #### Command example of the training

 ```bash
-(otx) ...$ otx train otx/algorithms/detection/configs/detection/mobilenetv2_ssd/template.yaml --train-ann-file data/airport/annotation_person_train.json --train-data-roots data/airport/train/ --val-ann-files data/airport/annotation_person_val.json --val-data-roots data/airport/val/ --save-model-to outputs
+(otx) ...$ otx train otx/algorithms/detection/configs/detection/mobilenetv2_ssd/template.yaml --train-data-roots tests/assets/car_tree_bug --val-data-roots tests/assets/car_tree_bug --save-model-to outputs --mem-cache-size 64MB
 ...
 ---------------iou_thr: 0.5---------------
-+--------+-----+------+--------+-------+
-| class  | gts | dets | recall |  ap   |
-+--------+-----+------+--------+-------+
-| person | 0   | 2000 | 0.000  | 0.000 |
-+--------+-----+------+--------+-------+
-| mAP    |     |      |        | 0.000 |
-+--------+-----+------+--------+-------+
-2022-11-17 11:08:15,245 | INFO : run task done.
-2022-11-17 11:08:15,318 | INFO : Inference completed
-2022-11-17 11:08:15,319 | INFO : called evaluate()
-2022-11-17 11:08:15,334 | INFO : F-measure after evaluation: 0.8809523809523808
-2022-11-17 11:08:15,334 | INFO : Evaluation completed
-Performance(score: 0.8809523809523808, dashboard: (1 metric groups))
++-------+-----+------+--------+-------+
+| class | gts | dets | recall |  ap   |
++-------+-----+------+--------+-------+
+| car   | 7   | 530  | 1.000  | 0.571 |
+| tree  | 7   | 585  | 1.000  | 0.929 |
+| bug   | 8   | 485  | 1.000  | 0.805 |
++-------+-----+------+--------+-------+
+| mAP   |     |      |        | 0.768 |
++-------+-----+------+--------+-------+
+2023-02-14 17:27:35,707 | INFO : Inference completed
+2023-02-14 17:27:35,707 | INFO : called evaluate()
+2023-02-14 17:27:35,714 | INFO : F-measure after evaluation: 0.6285714285714284
+2023-02-14 17:27:35,715 | INFO : Evaluation completed
+Performance(score: 0.6285714285714284, dashboard: (1 metric groups))
 ```

+- `--mem-cache-size 64MB` reserves a 64 MB memory pool in main memory for caching decoded images. Making the pool large enough to hold the decoded dataset can significantly speed up data loading during training.
+
 ### Exporting

 `export` exports a trained model to the OpenVINO format in order to efficiently run it on Intel hardware.

From ff598794c7d53dc24ec71c94783204412c2b299d Mon Sep 17 00:00:00 2001
From: "Kim, Vinnam"
Date: Tue, 14 Feb 2023 17:40:41 +0900
Subject: [PATCH 06/15] Fix comments

Signed-off-by: Kim, Vinnam
---
 tests/unit/core/data/test_caching.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/unit/core/data/test_caching.py b/tests/unit/core/data/test_caching.py
index 3a49969c25a..fa7ae67dd2c 100644
--- a/tests/unit/core/data/test_caching.py
+++ b/tests/unit/core/data/test_caching.py
@@ -131,7 +131,7 @@ def test_combine_with_dataloader(self, mode, fxt_caching_dataset_cls, fxt_data_l
             for _ in DataLoader(dataset):
                 continue
 
-        # This initial round requires all data samples to be put(). 
+ # This initial round requires all data samples to be read from disk. assert mock.call_count == len(dataset) with patch( @@ -141,5 +141,5 @@ def test_combine_with_dataloader(self, mode, fxt_caching_dataset_cls, fxt_data_l for _ in DataLoader(dataset): continue - # The second round requires no put(). + # The second round requires no read. assert mock.call_count == 0 From 7cd7166f42cfc8b8904556fd9e0fde1c9fe25a19 Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Tue, 14 Feb 2023 17:56:15 +0900 Subject: [PATCH 07/15] Rollback DatasetItemEntity.roi Signed-off-by: Kim, Vinnam --- otx/api/entities/dataset_item.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/otx/api/entities/dataset_item.py b/otx/api/entities/dataset_item.py index ba542ebcbc6..7e708611195 100644 --- a/otx/api/entities/dataset_item.py +++ b/otx/api/entities/dataset_item.py @@ -152,7 +152,11 @@ def __repr__(self): def roi(self) -> Annotation: """Region Of Interest.""" with self.__roi_lock: - requested_roi = self.__roi + if self.__roi is None: + requested_roi = Annotation(Rectangle.generate_full_box(), labels=[]) + self.__roi = requested_roi + else: + requested_roi = self.__roi return requested_roi @roi.setter From 90f8c99f017a10ed649969e9e44d3bef27f8641c Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Tue, 14 Feb 2023 17:59:25 +0900 Subject: [PATCH 08/15] Fix DatasetItemEntity.roi getter and setter Signed-off-by: Kim, Vinnam --- otx/api/entities/dataset_item.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/otx/api/entities/dataset_item.py b/otx/api/entities/dataset_item.py index 7e708611195..7976a03d2de 100644 --- a/otx/api/entities/dataset_item.py +++ b/otx/api/entities/dataset_item.py @@ -152,16 +152,13 @@ def __repr__(self): def roi(self) -> Annotation: """Region Of Interest.""" with self.__roi_lock: - if self.__roi is None: - requested_roi = Annotation(Rectangle.generate_full_box(), labels=[]) - self.__roi = requested_roi - else: - requested_roi = self.__roi - return requested_roi + return self.__roi @roi.setter def roi(self, roi: Optional[Annotation]): with self.__roi_lock: + if roi is None: + roi = Annotation(Rectangle.generate_full_box(), labels=[]) self.__roi = roi @property From b1535313edcddc880043bbd1ec9c2bcbbe28b754 Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Wed, 15 Feb 2023 11:42:29 +0900 Subject: [PATCH 09/15] Add yaml recipes to package_data Signed-off-by: Kim, Vinnam --- MANIFEST.in | 1 + setup.py | 32 ++++++++++++++++++++++---------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 849f47e5f3a..229aecc7bd1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,4 @@ recursive-include requirements * recursive-include otx *.pyx recursive-exclude otx *.c *.html +recursive-include otx/recipes *.yaml diff --git a/setup.py b/setup.py index 098624d432a..1583bb5ec26 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,7 @@ import subprocess import sys import warnings +from collections import defaultdict from glob import glob from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path @@ -19,7 +20,7 @@ from setuptools import Extension, find_packages, setup try: - from torch.utils.cpp_extension import CppExtension, BuildExtension + from torch.utils.cpp_extension import BuildExtension, CppExtension cmd_class = {"build_ext": BuildExtension} except ModuleNotFoundError: @@ -82,9 +83,7 @@ def get_requirements(requirement_files: Union[str, List[str]]) -> List[str]: requirements: List[str] = 
[] for requirement_file in requirement_files: - with open( - f"requirements/{requirement_file}.txt", "r", encoding="UTF-8" - ) as file: + with open(f"requirements/{requirement_file}.txt", "r", encoding="UTF-8") as file: for line in file: package = line.strip() if package and not package.startswith(("#", "-f")): @@ -127,9 +126,7 @@ def _cython_modules(): "classification": get_requirements(requirement_files="classification"), "detection": get_requirements(requirement_files="detection"), "segmentation": get_requirements(requirement_files="segmentation"), - "mpa": get_requirements( - requirement_files=["classification", "detection", "segmentation", "action"] - ), + "mpa": get_requirements(requirement_files=["classification", "detection", "segmentation", "action"]), "full": get_requirements( requirement_files=[ "anomaly", @@ -142,13 +139,28 @@ def _cython_modules(): } +def find_yaml_recipes(): + """Find YAML recipe files in the package.""" + results = defaultdict(list) + + for root, _, files in os.walk(os.path.join("otx", "recipes")): + module = ".".join(root.split(os.sep)) + for file in files: + _, ext = os.path.splitext(file) + if ext == ".yaml": + results[module] += [file] + + return results + + +package_data = {"": ["requirements.txt", "README.md", "LICENSE"]} # Needed for exportable code +package_data.update(find_yaml_recipes()) + setup( name="otx", version=get_otx_version(), packages=find_packages(exclude=("tests",)), - package_data={ - "": ["requirements.txt", "README.md", "LICENSE"] - }, # Needed for exportable code + package_data=package_data, ext_modules=get_extensions(), cmdclass=cmd_class, install_requires=REQUIRED_PACKAGES, From 6116934f652c6b7ea965b458302eb061e783e475 Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Wed, 15 Feb 2023 12:30:46 +0900 Subject: [PATCH 10/15] Fix Codacy error Signed-off-by: Kim, Vinnam --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1583bb5ec26..733cdb11fad 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ from setuptools import Extension, find_packages, setup try: - from torch.utils.cpp_extension import BuildExtension, CppExtension + from torch.utils.cpp_extension import BuildExtension cmd_class = {"build_ext": BuildExtension} except ModuleNotFoundError: From 837b1c4ada5eb9e941ba897745f2adeedcb3553d Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Wed, 15 Feb 2023 12:35:59 +0900 Subject: [PATCH 11/15] Change find path from otx/recipes to otx Signed-off-by: Kim, Vinnam --- MANIFEST.in | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 229aecc7bd1..de5b4e10000 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ recursive-include requirements * recursive-include otx *.pyx recursive-exclude otx *.c *.html -recursive-include otx/recipes *.yaml +recursive-include otx *.yaml diff --git a/setup.py b/setup.py index 733cdb11fad..69106bb42f3 100644 --- a/setup.py +++ b/setup.py @@ -143,7 +143,7 @@ def find_yaml_recipes(): """Find YAML recipe files in the package.""" results = defaultdict(list) - for root, _, files in os.walk(os.path.join("otx", "recipes")): + for root, _, files in os.walk("otx"): module = ".".join(root.split(os.sep)) for file in files: _, ext = os.path.splitext(file) From 30c9f0db37cf32238112b934e6537e2a7088f96a Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Thu, 16 Feb 2023 10:14:53 +0900 Subject: [PATCH 12/15] Fix MemSizeAction to follow the existing hyperparams parsing rule - Change parsing destination 
from args.mem_cache_size to args.params.algo_backend.mem_cache_size Signed-off-by: Kim, Vinnam --- otx/cli/manager/config_manager.py | 8 +------- otx/cli/tools/train.py | 1 + otx/cli/utils/parser.py | 3 +++ tests/unit/cli/utils/test_parser.py | 3 ++- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/otx/cli/manager/config_manager.py b/otx/cli/manager/config_manager.py index 5243c983949..bd85c088023 100644 --- a/otx/cli/manager/config_manager.py +++ b/otx/cli/manager/config_manager.py @@ -309,13 +309,7 @@ def get_hyparams_config(self) -> ConfigurableParameters: type_hint = gen_param_help(hyper_parameters) updated_hyper_parameters = gen_params_dict_from_args(self.args, type_hint=type_hint) override_parameters(updated_hyper_parameters, hyper_parameters) - - # (vinnamki) I added this line because these lines above looks like - # it parses out of the arguments, but I'm wondering if this is working. - hyper_parameters = create(hyper_parameters) - hyper_parameters.algo_backend.mem_cache_size = self.args.mem_cache_size - - return hyper_parameters + return create(hyper_parameters) def get_dataset_config(self, subsets: List[str]) -> dict: """Returns dataset_config in a format suitable for each subset. diff --git a/otx/cli/tools/train.py b/otx/cli/tools/train.py index 8705e91a4b1..2e0bd6cbc8e 100644 --- a/otx/cli/tools/train.py +++ b/otx/cli/tools/train.py @@ -115,6 +115,7 @@ def get_args(): parser.add_argument( "--mem-cache-size", action=MemSizeAction, + dest="params.algo_backend.mem_cache_size", type=str, required=False, default=0, diff --git a/otx/cli/utils/parser.py b/otx/cli/utils/parser.py index f6044751806..077d4dad6e7 100644 --- a/otx/cli/utils/parser.py +++ b/otx/cli/utils/parser.py @@ -30,6 +30,9 @@ class MemSizeAction(argparse.Action): def __init__(self, option_strings, dest, nargs=None, **kwargs): if nargs is not None: raise ValueError("nargs not allowed") + expected_dest = "params.algo_backend.mem_cache_size" + if dest != expected_dest: + raise ValueError(f"dest should be {expected_dest}, but dest={dest}.") super().__init__(option_strings, dest, **kwargs) def __call__(self, parser, namespace, values, option_string=None): diff --git a/tests/unit/cli/utils/test_parser.py b/tests/unit/cli/utils/test_parser.py index 003a655aa87..47cd0d9240e 100644 --- a/tests/unit/cli/utils/test_parser.py +++ b/tests/unit/cli/utils/test_parser.py @@ -215,6 +215,7 @@ def fxt_argparse(): parser = ArgumentParser() parser.add_argument( "--mem-cache-size", + dest="params.algo_backend.mem_cache_size", action=MemSizeAction, type=str, required=False, @@ -243,6 +244,6 @@ def fxt_argparse(): def test_mem_size_parsing(fxt_argparse, mem_size_arg, expected): try: args = fxt_argparse.parse_args(["--mem-cache-size", mem_size_arg]) - assert args.mem_cache_size == expected + assert getattr(args, "params.algo_backend.mem_cache_size") == expected except ValueError: assert expected is None From ca7681f57a975e2006548a4ff4c56fb537258bf8 Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Thu, 16 Feb 2023 16:55:28 +0900 Subject: [PATCH 13/15] Fix for anomaly and action Signed-off-by: Kim, Vinnam --- .../configs/classification/configuration.yaml | 16 ++++++++ .../configs/detection/configuration.yaml | 16 ++++++++ otx/algorithms/common/tasks/training_base.py | 12 +++++- otx/cli/tools/train.py | 1 - otx/cli/utils/parser.py | 39 ++++++++++++------- .../pipelines/load_image_from_otx_dataset.py | 7 +++- tests/unit/cli/utils/test_parser.py | 6 ++- 7 files changed, 76 insertions(+), 21 deletions(-) diff --git 
a/otx/algorithms/action/configs/classification/configuration.yaml b/otx/algorithms/action/configs/classification/configuration.yaml index e9c63c445c5..5221eaaa03b 100644 --- a/otx/algorithms/action/configs/classification/configuration.yaml +++ b/otx/algorithms/action/configs/classification/configuration.yaml @@ -261,6 +261,22 @@ algo_backend: value: INCREMENTAL visible_in_ui: True warning: null + mem_cache_size: + affects_outcome_of: TRAINING + default_value: 0 + description: Size of memory pool for caching decoded data to load data faster (bytes). + editable: true + header: Size of memory pool + max_value: 9223372036854775807 + min_value: 0 + type: INTEGER + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + visible_in_ui: false + warning: null type: PARAMETER_GROUP visible_in_ui: true type: CONFIGURABLE_PARAMETERS diff --git a/otx/algorithms/action/configs/detection/configuration.yaml b/otx/algorithms/action/configs/detection/configuration.yaml index e9c63c445c5..5221eaaa03b 100644 --- a/otx/algorithms/action/configs/detection/configuration.yaml +++ b/otx/algorithms/action/configs/detection/configuration.yaml @@ -261,6 +261,22 @@ algo_backend: value: INCREMENTAL visible_in_ui: True warning: null + mem_cache_size: + affects_outcome_of: TRAINING + default_value: 0 + description: Size of memory pool for caching decoded data to load data faster (bytes). + editable: true + header: Size of memory pool + max_value: 9223372036854775807 + min_value: 0 + type: INTEGER + ui_rules: + action: DISABLE_EDITING + operator: AND + rules: [] + type: UI_RULES + visible_in_ui: false + warning: null type: PARAMETER_GROUP visible_in_ui: true type: CONFIGURABLE_PARAMETERS diff --git a/otx/algorithms/common/tasks/training_base.py b/otx/algorithms/common/tasks/training_base.py index f1b8b702326..3a76c4ed4be 100644 --- a/otx/algorithms/common/tasks/training_base.py +++ b/otx/algorithms/common/tasks/training_base.py @@ -627,9 +627,17 @@ def _find_max_num_workers(cfg: dict): return max(num_workers) - _, world_size = get_dist_info() - mem_cache_size = self.hyperparams.algo_backend.mem_cache_size // world_size + def _get_mem_cache_size(): + if not hasattr(self.hyperparams, "algo_backend"): + return 0 + if not hasattr(self.hyperparams.algo_backend, "mem_cache_size"): + return 0 + + _, world_size = get_dist_info() + return self.hyperparams.algo_backend.mem_cache_size // world_size + max_num_workers = _find_max_num_workers(data_cfg) + mem_cache_size = _get_mem_cache_size() mode = "multiprocessing" if max_num_workers > 0 else "singleprocessing" caching.MemCacheHandlerSingleton.create(mode, mem_cache_size) diff --git a/otx/cli/tools/train.py b/otx/cli/tools/train.py index 2e0bd6cbc8e..5d670e98ee2 100644 --- a/otx/cli/tools/train.py +++ b/otx/cli/tools/train.py @@ -118,7 +118,6 @@ def get_args(): dest="params.algo_backend.mem_cache_size", type=str, required=False, - default=0, help="Size of memory pool for caching decoded data to load data faster. " "For example, you can use digits for bytes size (e.g. 1024) or a string with size units " "(e.g. 
7KB = 7 * 2^10, 3MB = 3 * 2^20, and 2GB = 2 * 2^30).",
diff --git a/otx/cli/utils/parser.py b/otx/cli/utils/parser.py
index 077d4dad6e7..daa79203ccf 100644
--- a/otx/cli/utils/parser.py
+++ b/otx/cli/utils/parser.py
@@ -111,25 +111,36 @@ def _gen_param_help(prefix: str, cur_params: Dict) -> Dict:
 
 def gen_params_dict_from_args(args, type_hint: Optional[dict] = None) -> Dict[str, dict]:
     """Generates hyper parameters dict from parsed command line arguments."""
 
+    def _get_leaf_node(curr_dict: Dict[str, dict], curr_key: str):
+        split_key = curr_key.split(".")
+        node_key = split_key[0]
+
+        if len(split_key) == 1:
+            # This is a leaf node
+            return curr_dict, node_key
+
+        # Dive deeper
+        curr_key = ".".join(split_key[1:])
+        if node_key not in curr_dict:
+            curr_dict[node_key] = {}
+        return _get_leaf_node(curr_dict[node_key], curr_key)
+
+    _prefix = "params."
     params_dict: Dict[str, dict] = {}
     for param_name in dir(args):
-        if not param_name.startswith("params."):
+        value = getattr(args, param_name)
+
+        if not param_name.startswith(_prefix) or value is None:
             continue
+        # Equivalent to param_name.removeprefix(_prefix) in Python 3.9+
+        origin_key = param_name[len(_prefix) :]
         value_type = None
-        cur_dict = params_dict
-        split_param_name = param_name.split(".")[1:]
-        if type_hint:
-            origin_key = ".".join(split_param_name)
-            value_type = type_hint[origin_key].get("type", None)
-        for i, k in enumerate(split_param_name):
-            if k not in cur_dict:
-                cur_dict[k] = {}
-            if i < len(split_param_name) - 1:
-                cur_dict = cur_dict[k]
-            else:
-                value = getattr(args, param_name)
-                cur_dict[k] = {"value": value_type(value) if value_type else value}
+        if type_hint is not None:
+            value_type = type_hint.get(origin_key, {}).get("type", None)
+
+        leaf_node_dict, node_key = _get_leaf_node(params_dict, origin_key)
+        leaf_node_dict[node_key] = {"value": value_type(value) if value_type else value}
 
     return params_dict
diff --git a/otx/core/data/pipelines/load_image_from_otx_dataset.py b/otx/core/data/pipelines/load_image_from_otx_dataset.py
index 281554d9eb2..e9243cb4bcc 100644
--- a/otx/core/data/pipelines/load_image_from_otx_dataset.py
+++ b/otx/core/data/pipelines/load_image_from_otx_dataset.py
@@ -59,8 +59,11 @@ def __call__(self, results: Dict[str, Any]):
             img = img.astype(np.float32)
         shape = img.shape
 
-        assert img.shape[0] == results["height"], f"{img.shape[0]} != {results['height']}"
-        assert img.shape[1] == results["width"], f"{img.shape[1]} != {results['width']}"
+        if img.shape[0] != results["height"]:
+            results["height"] = img.shape[0]
+
+        if img.shape[1] != results["width"]:
+            results["width"] = img.shape[1]
 
         filename = f"Dataset item index {results['index']}"
         results["filename"] = filename
diff --git a/tests/unit/cli/utils/test_parser.py b/tests/unit/cli/utils/test_parser.py
index 47cd0d9240e..ca687a73d4c 100644
--- a/tests/unit/cli/utils/test_parser.py
+++ b/tests/unit/cli/utils/test_parser.py
@@ -82,6 +82,7 @@ def mock_args(mocker):
     setattr(mock_args, "params.a.c", True)
     setattr(mock_args, "params.b", "fake")
     setattr(mock_args, "params.c", 10)
+    setattr(mock_args, "params.d", None)
 
     return mock_args
 
@@ -95,6 +96,7 @@ def test_gen_params_dict_from_args(mock_args):
     assert param_dict["a"]["c"]["value"] is True
     assert param_dict["b"]["value"] == "fake"
     assert param_dict["c"]["value"] == 10
+    assert "d" not in param_dict
 
 
 @e2e_pytest_unit
@@ -191,7 +193,7 @@ def test_get_parser_and_hprams_data_with_template(mocker, tmp_dir):
     # check
     mock_template.assert_called_once()
     assert hyper_parameters == mock_hyper_parameters
-    assert 
params == [] assert isinstance(parser, ArgumentParser) @@ -206,7 +208,7 @@ def test_get_parser_and_hprams_data(mocker): # check assert hyper_parameters == {} - assert params == ["--left-args"] + assert params == [] assert isinstance(parser, ArgumentParser) From cf8918972de8e1d569c26094ec9bbf669f47c0fc Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Thu, 16 Feb 2023 23:25:07 +0900 Subject: [PATCH 14/15] Fix unit test Signed-off-by: Kim, Vinnam --- otx/core/data/caching/__init__.py | 4 ++-- otx/core/data/caching/mem_cache_handler.py | 8 ++++++-- otx/core/data/pipelines/load_image_from_otx_dataset.py | 9 +++++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/otx/core/data/caching/__init__.py b/otx/core/data/caching/__init__.py index 64232048d44..f604a62e843 100644 --- a/otx/core/data/caching/__init__.py +++ b/otx/core/data/caching/__init__.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # -from .mem_cache_handler import MemCacheHandlerSingleton +from .mem_cache_handler import MemCacheHandlerError, MemCacheHandlerSingleton from .mem_cache_hook import MemCacheHook -__all__ = ["MemCacheHandlerSingleton", "MemCacheHook"] +__all__ = ["MemCacheHandlerSingleton", "MemCacheHook", "MemCacheHandlerError"] diff --git a/otx/core/data/caching/mem_cache_handler.py b/otx/core/data/caching/mem_cache_handler.py index 351658fa192..22259cb9427 100644 --- a/otx/core/data/caching/mem_cache_handler.py +++ b/otx/core/data/caching/mem_cache_handler.py @@ -146,6 +146,10 @@ def __del__(self): self._manager.shutdown() +class MemCacheHandlerError(Exception): + """Exception class for MemCacheHandler.""" + + class MemCacheHandlerSingleton: """A singleton class to create, delete and get MemCacheHandlerBase.""" @@ -159,7 +163,7 @@ def get(cls) -> MemCacheHandlerBase: """ if not hasattr(cls, "instance"): cls_name = cls.__class__.__name__ - raise RuntimeError(f"Before calling {cls_name}.get(), you should call {cls_name}.create() first.") + raise MemCacheHandlerError(f"Before calling {cls_name}.get(), you should call {cls_name}.create() first.") return cls.instance @@ -181,7 +185,7 @@ def create(cls, mode: str, mem_size: int) -> MemCacheHandlerBase: elif mode == "singleprocessing": cls.instance = MemCacheHandlerForSP(mem_size) else: - raise ValueError(f"{mode} is unknown mode.") + raise MemCacheHandlerError(f"{mode} is unknown mode.") return cls.instance diff --git a/otx/core/data/pipelines/load_image_from_otx_dataset.py b/otx/core/data/pipelines/load_image_from_otx_dataset.py index e9243cb4bcc..355935fae74 100644 --- a/otx/core/data/pipelines/load_image_from_otx_dataset.py +++ b/otx/core/data/pipelines/load_image_from_otx_dataset.py @@ -10,7 +10,7 @@ from otx.algorithms.common.utils.data import get_image from otx.api.utils.argument_checks import check_input_parameters_type -from ..caching import MemCacheHandlerSingleton +from ..caching import MemCacheHandlerError, MemCacheHandlerSingleton _CACHE_DIR = TemporaryDirectory(prefix="img-cache-") # pylint: disable=consider-using-with @@ -33,7 +33,12 @@ class LoadImageFromOTXDataset: @check_input_parameters_type() def __init__(self, to_float32: bool = False): self.to_float32 = to_float32 - self.mem_cache_handler = MemCacheHandlerSingleton.get() + try: + self.mem_cache_handler = MemCacheHandlerSingleton.get() + except MemCacheHandlerError: + # Create a dummy handler + MemCacheHandlerSingleton.create(mode="singleprocessing", mem_size=0) + self.mem_cache_handler = MemCacheHandlerSingleton.get() @staticmethod def _get_unique_key(results: Dict[str, Any]) -> 
Tuple:

From 85496848531e360f24ff740141cdabcd58ca1706 Mon Sep 17 00:00:00 2001
From: "Kim, Vinnam"
Date: Mon, 6 Mar 2023 18:41:41 +0900
Subject: [PATCH 15/15] Clean up some code

Signed-off-by: Kim, Vinnam
---
 otx/algorithms/common/tasks/training_base.py       |  8 ++------
 otx/core/data/caching/mem_cache_handler.py         | 12 +++++++++---
 otx/core/data/caching/mem_cache_hook.py            | 12 ++++++++----
 .../data/pipelines/load_image_from_otx_dataset.py  |  4 ++--
 4 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/otx/algorithms/common/tasks/training_base.py b/otx/algorithms/common/tasks/training_base.py
index e477e7800aa..a58fdf0e20f 100644
--- a/otx/algorithms/common/tasks/training_base.py
+++ b/otx/algorithms/common/tasks/training_base.py
@@ -25,7 +25,6 @@
 
 import numpy as np
 import torch
-from mmcv.runner import get_dist_info
 from mmcv.utils.config import Config, ConfigDict
 
 from otx.algorithms.common.adapters.mmcv.hooks import OTXLoggerHook
@@ -637,13 +636,10 @@ def _find_max_num_workers(cfg: dict):
             return max(num_workers)
 
         def _get_mem_cache_size():
-            if not hasattr(self.hyperparams, "algo_backend"):
-                return 0
             if not hasattr(self.hyperparams.algo_backend, "mem_cache_size"):
                 return 0
 
-            _, world_size = get_dist_info()
-            return self.hyperparams.algo_backend.mem_cache_size // world_size
+            return self.hyperparams.algo_backend.mem_cache_size
 
         max_num_workers = _find_max_num_workers(data_cfg)
         mem_cache_size = _get_mem_cache_size()
@@ -653,5 +649,5 @@ def _get_mem_cache_size():
 
         update_or_add_custom_hook(
             self._recipe_cfg,
-            ConfigDict(type="MemCacheHook"),
+            ConfigDict(type="MemCacheHook", priority="VERY_LOW"),
         )
diff --git a/otx/core/data/caching/mem_cache_handler.py b/otx/core/data/caching/mem_cache_handler.py
index 22259cb9427..44cab53e051 100644
--- a/otx/core/data/caching/mem_cache_handler.py
+++ b/otx/core/data/caching/mem_cache_handler.py
@@ -9,6 +9,7 @@
 from typing import Any, Dict, Optional, Union
 
 import numpy as np
+from mmcv.runner import get_dist_info
 from multiprocess.synchronize import Lock
 
 from otx.mpa.utils.logger import get_logger
@@ -172,13 +173,18 @@ def create(cls, mode: str, mem_size: int) -> MemCacheHandlerBase:
         """Create a new MemCacheHandlerBase instance.
 
         Args:
-            mode (str): There are two options: multiprocessing or singleprocessing.
+            mode (str): There are three options: null, multiprocessing, or singleprocessing.
             mem_size (int): The size of memory pool (bytes).
         """
         logger.info(f"Try to create a {mem_size} size memory pool.")
 
-        if mem_size == 0:
-            cls.instance = MemCacheHandlerBase(mem_size)
+        _, world_size = get_dist_info()
+        if world_size > 1:
+            mem_size = mem_size // world_size
+            logger.info(f"Since world_size={world_size} > 1, each worker gets a {mem_size} size memory pool.")
+
+        if mode == "null" or mem_size == 0:
+            cls.instance = MemCacheHandlerBase(mem_size=0)
             cls.instance.freeze()
         elif mode == "multiprocessing":
             cls.instance = MemCacheHandlerForMP(mem_size)
diff --git a/otx/core/data/caching/mem_cache_hook.py b/otx/core/data/caching/mem_cache_hook.py
index 1dc1462ae6d..ecd48fa840a 100644
--- a/otx/core/data/caching/mem_cache_hook.py
+++ b/otx/core/data/caching/mem_cache_hook.py
@@ -14,16 +14,20 @@ class MemCacheHook(Hook):
     def __init__(self) -> None:
         self.handler = MemCacheHandlerSingleton.get()
-
-    def before_run(self, runner):
-        """Before run, freeze the handler."""
+        # This is because the first evaluation comes at the very beginning of the training.
+        # We don't want to cache validation samples first. 
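+        # before_epoch() below unfreezes the handler once training batches start flowing.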
self.handler.freeze() def before_epoch(self, runner): """Before training, unfreeze the handler.""" + # We want to cache training samples first. self.handler.unfreeze() def after_epoch(self, runner): - """After epoch. Log the handler statistics.""" + """After epoch. Log the handler statistics. + + To prevent it from skipping the validation samples, + this hook should have lower priority than CustomEvalHook. + """ self.handler.freeze() runner.logger.info(f"{self.handler}") diff --git a/otx/core/data/pipelines/load_image_from_otx_dataset.py b/otx/core/data/pipelines/load_image_from_otx_dataset.py index 355935fae74..100c24c7e25 100644 --- a/otx/core/data/pipelines/load_image_from_otx_dataset.py +++ b/otx/core/data/pipelines/load_image_from_otx_dataset.py @@ -36,8 +36,8 @@ def __init__(self, to_float32: bool = False): try: self.mem_cache_handler = MemCacheHandlerSingleton.get() except MemCacheHandlerError: - # Create a dummy handler - MemCacheHandlerSingleton.create(mode="singleprocessing", mem_size=0) + # Create a null handler + MemCacheHandlerSingleton.create(mode="null", mem_size=0) self.mem_cache_handler = MemCacheHandlerSingleton.get() @staticmethod