From 34788b50156a2b8e8ec09bd6b791a89a0ca58f1e Mon Sep 17 00:00:00 2001 From: eunwoosh Date: Fri, 6 Oct 2023 17:48:55 +0900 Subject: [PATCH] release mem cache handler after training is done --- src/otx/algorithms/classification/task.py | 3 +++ src/otx/algorithms/detection/task.py | 3 +++ src/otx/algorithms/segmentation/task.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/src/otx/algorithms/classification/task.py b/src/otx/algorithms/classification/task.py index 03da102188f..1a811ab7234 100644 --- a/src/otx/algorithms/classification/task.py +++ b/src/otx/algorithms/classification/task.py @@ -80,6 +80,7 @@ from otx.api.utils.dataset_utils import add_saliency_maps_to_dataset_item from otx.api.utils.labels_utils import get_empty_label from otx.cli.utils.multi_gpu import is_multigpu_child_process +from otx.core.data.caching.mem_cache_handler import MemCacheHandlerSingleton logger = get_logger() RECIPE_TRAIN_TYPE = { @@ -215,6 +216,8 @@ def train( results = self._train_model(dataset) + MemCacheHandlerSingleton.delete() + # Check for stop signal when training has stopped. If should_stop is true, training was cancelled and no new if self._should_stop: logger.info("Training cancelled.") diff --git a/src/otx/algorithms/detection/task.py b/src/otx/algorithms/detection/task.py index 9a3d9cca885..d87ce125f38 100644 --- a/src/otx/algorithms/detection/task.py +++ b/src/otx/algorithms/detection/task.py @@ -65,6 +65,7 @@ from otx.api.usecases.tasks.interfaces.export_interface import ExportType from otx.api.utils.dataset_utils import add_saliency_maps_to_dataset_item from otx.cli.utils.multi_gpu import is_multigpu_child_process +from otx.core.data.caching.mem_cache_handler import MemCacheHandlerSingleton logger = get_logger() @@ -231,6 +232,8 @@ def train( val_dataset.purpose = DatasetPurpose.INFERENCE val_preds, val_map = self._infer_model(val_dataset, InferenceParameters(is_evaluation=True)) + MemCacheHandlerSingleton.delete() + preds_val_dataset = val_dataset.with_empty_annotations() if self._hyperparams.postprocessing.result_based_confidence_threshold: confidence_threshold = 0.0 # Use all predictions to compute best threshold diff --git a/src/otx/algorithms/segmentation/task.py b/src/otx/algorithms/segmentation/task.py index a62270bb13c..779bdd10edc 100644 --- a/src/otx/algorithms/segmentation/task.py +++ b/src/otx/algorithms/segmentation/task.py @@ -70,6 +70,7 @@ create_hard_prediction_from_soft_prediction, ) from otx.cli.utils.multi_gpu import is_multigpu_child_process +from otx.core.data.caching.mem_cache_handler import MemCacheHandlerSingleton logger = get_logger() RECIPE_TRAIN_TYPE = { @@ -171,6 +172,8 @@ def train( results = self._train_model(dataset) + MemCacheHandlerSingleton.delete() + # Check for stop signal when training has stopped. If should_stop is true, training was cancelled and no new if self._should_stop: logger.info("Training cancelled.")