From c8e7d592796da43de4ec11eae68ff50f271606d1 Mon Sep 17 00:00:00 2001
From: eunwoosh
Date: Fri, 25 Nov 2022 21:09:04 +0900
Subject: [PATCH] make all processes have the same output path

---
 otx/algorithms/common/tasks/training_base.py | 12 +++++++++---
 otx/cli/tools/train.py                       |  9 +++++----
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/otx/algorithms/common/tasks/training_base.py b/otx/algorithms/common/tasks/training_base.py
index 8ef1d6f755d..9fccb4b1f55 100644
--- a/otx/algorithms/common/tasks/training_base.py
+++ b/otx/algorithms/common/tasks/training_base.py
@@ -61,15 +61,17 @@ def __init__(self, task_config, task_environment: TaskEnvironment):
         self._model_name = task_environment.model_template.name
         self._task_type = task_environment.model_template.task_type
         self._labels = task_environment.get_labels(include_empty=False)
-        output_path = getattr(task_environment, "work_dir", None)
-        self._output_path = tempfile.mkdtemp(prefix="OTX-task-") if output_path is None else output_path
-        logger.info(f"created output path at {self._output_path}")
         self.confidence_threshold = self._get_confidence_threshold(self._hyperparams)
         # Set default model attributes.
         self._model_label_schema = []  # type: List[LabelEntity]
         self._optimization_methods = []  # type: List[OptimizationMethod]
         self._model_ckpt = None
         self._anchors = {}  # type: Dict[str, int]
+        output_path = os.environ.get("OTX_TASK_OUTPUT_PATH")
+        if output_path is None:
+            output_path = tempfile.mkdtemp(prefix="OTX-task-")
+        self._output_path = output_path
+        logger.info(f"created output path at {self._output_path}")
         if task_environment.model is not None:
             logger.info("loading the model from the task env.")
             state_dict = self._load_model_state_dict(self._task_environment.model)
@@ -98,6 +100,10 @@ def __init__(self, task_config, task_environment: TaskEnvironment):
         # to override configuration at runtime
         self.override_configs = {}  # type: Dict[str, str]
 
+    @property
+    def output_path(self):
+        return self._output_path
+
     def _run_task(self, stage_module, mode=None, dataset=None, **kwargs):
         # FIXME: Temporary remedy for CVS-88098
         export = kwargs.get("export", False)
diff --git a/otx/cli/tools/train.py b/otx/cli/tools/train.py
index 16d58d5d692..d0d27b37225 100644
--- a/otx/cli/tools/train.py
+++ b/otx/cli/tools/train.py
@@ -217,10 +217,10 @@ def main():
     if args.multi_gpu_train and not template.task_type.is_anomaly:
         gpu_ids = get_gpu_ids()
         if len(gpu_ids) > 1:
+            multi_gpu_train_args = [gpu_ids, task.output_path]
             if args.enable_hpo:
-                multi_gpu_processes = run_multi_gpu_train(gpu_ids, hyper_parameters)
-            else:
-                multi_gpu_processes = run_multi_gpu_train(gpu_ids)
+                multi_gpu_train_args.append(hyper_parameters)
+            multi_gpu_processes = run_multi_gpu_train(*multi_gpu_train_args)
         else:
             print("Number of avilable gpu is lower than 2. Multi GPU training won't be executed.")
 
@@ -297,10 +297,11 @@ def terminate_signal_handler(signum, frame, processes: List[mp.Process]):
    sys.exit(1)
 
 
-def run_multi_gpu_train(gpu_ids: List[int], optimized_hyper_parameters=None):
+def run_multi_gpu_train(gpu_ids: List[int], output_path: str, optimized_hyper_parameters=None):
     if optimized_hyper_parameters is not None:
         set_optimized_hp_for_child_process(optimized_hyper_parameters)
 
+    os.environ['OTX_TASK_OUTPUT_PATH'] = output_path
    processes= []
    spawned_mp = mp.get_context("spawn")
    for rank in range(1, len(gpu_ids)):
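
Note: the mechanism this patch relies on is that child processes created with the "spawn" start method inherit a copy of the parent's os.environ, so exporting OTX_TASK_OUTPUT_PATH before spawning makes every rank resolve the same directory, with mkdtemp kept only as a fallback. Below is a minimal, self-contained sketch of that handoff; parent_main and child_main are hypothetical names used for illustration, and only the OTX_TASK_OUTPUT_PATH variable comes from the patch itself.

import multiprocessing as mp
import os
import tempfile


def child_main(rank: int) -> None:
    # Each child reads the path the parent exported; the mkdtemp call is
    # only a fallback, mirroring the patched task __init__.
    output_path = os.environ.get("OTX_TASK_OUTPUT_PATH")
    if output_path is None:
        output_path = tempfile.mkdtemp(prefix="OTX-task-")
    print(f"rank {rank} uses {output_path}")


def parent_main() -> None:
    # Export before spawning: spawned children start with a copy of the
    # parent's environment, so every rank sees the same value.
    output_path = tempfile.mkdtemp(prefix="OTX-task-")
    os.environ["OTX_TASK_OUTPUT_PATH"] = output_path
    spawned_mp = mp.get_context("spawn")
    processes = [spawned_mp.Process(target=child_main, args=(rank,))
                 for rank in range(1, 3)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
    print(f"parent used {output_path}")


if __name__ == "__main__":
    parent_main()

Passing the path through the environment rather than as a Process argument means the task constructor needs no new parameter and works unchanged in both parent and children; the trade-off is an implicit dependency on a process-wide variable.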