make all processes have same output path
eunwoosh committed Dec 8, 2022
1 parent b92ab94 commit c8e7d59
Showing 2 changed files with 14 additions and 7 deletions.
12 changes: 9 additions & 3 deletions otx/algorithms/common/tasks/training_base.py
@@ -61,15 +61,17 @@ def __init__(self, task_config, task_environment: TaskEnvironment):
         self._model_name = task_environment.model_template.name
         self._task_type = task_environment.model_template.task_type
         self._labels = task_environment.get_labels(include_empty=False)
-        output_path = getattr(task_environment, "work_dir", None)
-        self._output_path = tempfile.mkdtemp(prefix="OTX-task-") if output_path is None else output_path
-        logger.info(f"created output path at {self._output_path}")
         self.confidence_threshold = self._get_confidence_threshold(self._hyperparams)
         # Set default model attributes.
         self._model_label_schema = []  # type: List[LabelEntity]
         self._optimization_methods = []  # type: List[OptimizationMethod]
         self._model_ckpt = None
         self._anchors = {}  # type: Dict[str, int]
+        output_path = os.environ.get("OTX_TASK_OUTPUT_PATH")
+        if output_path is None:
+            output_path = tempfile.mkdtemp(prefix="OTX-task-")
+        self._output_path = output_path
+        logger.info(f"created output path at {self._output_path}")
         if task_environment.model is not None:
             logger.info("loading the model from the task env.")
             state_dict = self._load_model_state_dict(self._task_environment.model)
@@ -98,6 +100,10 @@ def __init__(self, task_config, task_environment: TaskEnvironment):
         # to override configuration at runtime
         self.override_configs = {}  # type: Dict[str, str]
 
+    @property
+    def output_path(self):
+        return self._output_path
+
     def _run_task(self, stage_module, mode=None, dataset=None, **kwargs):
         # FIXME: Temporary remedy for CVS-88098
         export = kwargs.get("export", False)
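For context, the hunk above makes the task resolve its output directory from the OTX_TASK_OUTPUT_PATH environment variable, falling back to a fresh temporary directory only when the variable is unset, and then exposes the result through a read-only output_path property. A minimal standalone sketch of that pattern (the DummyTask class name is hypothetical, not part of the commit):

import os
import tempfile


class DummyTask:
    def __init__(self):
        # Prefer a path handed down via the environment (e.g. by a parent
        # process); otherwise create a fresh temporary directory.
        output_path = os.environ.get("OTX_TASK_OUTPUT_PATH")
        if output_path is None:
            output_path = tempfile.mkdtemp(prefix="OTX-task-")
        self._output_path = output_path

    @property
    def output_path(self):
        # Read-only accessor mirroring the property added above.
        return self._output_path


if __name__ == "__main__":
    print(DummyTask().output_path)  # a shared path if the variable is set, else a temp dir

Because every process that constructs a task reads the same variable, all of them end up writing under a single directory instead of each creating its own.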
9 changes: 5 additions & 4 deletions otx/cli/tools/train.py
@@ -217,10 +217,10 @@ def main():
     if args.multi_gpu_train and not template.task_type.is_anomaly:
         gpu_ids = get_gpu_ids()
         if len(gpu_ids) > 1:
+            multi_gpu_train_args = [gpu_ids, task.output_path]
             if args.enable_hpo:
-                multi_gpu_processes = run_multi_gpu_train(gpu_ids, hyper_parameters)
-            else:
-                multi_gpu_processes = run_multi_gpu_train(gpu_ids)
+                multi_gpu_train_args.append(hyper_parameters)
+            multi_gpu_processes = run_multi_gpu_train(*multi_gpu_train_args)
         else:
             print("Number of avilable gpu is lower than 2. Multi GPU training won't be executed.")

@@ -297,10 +297,11 @@ def terminate_signal_handler(signum, frame, processes: List[mp.Process]):
 
     sys.exit(1)
 
-def run_multi_gpu_train(gpu_ids: List[int], optimized_hyper_parameters=None):
+def run_multi_gpu_train(gpu_ids: List[int], output_path: str, optimized_hyper_parameters=None):
     if optimized_hyper_parameters is not None:
         set_optimized_hp_for_child_process(optimized_hyper_parameters)
 
+    os.environ['OTX_TASK_OUTPUT_PATH'] = output_path
     processes = []
     spawned_mp = mp.get_context("spawn")
     for rank in range(1, len(gpu_ids)):
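The train.py change relies on the fact that environment variables set in the parent before worker processes are spawned are inherited by those children, so every spawned task constructor picks up the same OTX_TASK_OUTPUT_PATH instead of creating its own temporary directory. A small sketch of that propagation, assuming a hypothetical launch() helper in place of run_multi_gpu_train:

import multiprocessing as mp
import os


def _worker(rank):
    # Spawned children inherit the parent's environment, so they all see the
    # same output path that the parent exported before spawning.
    print(f"rank {rank}: OTX_TASK_OUTPUT_PATH={os.environ.get('OTX_TASK_OUTPUT_PATH')}")


def launch(output_path, num_workers=2):
    os.environ["OTX_TASK_OUTPUT_PATH"] = output_path  # export before spawning
    ctx = mp.get_context("spawn")
    processes = [ctx.Process(target=_worker, args=(rank,)) for rank in range(1, num_workers + 1)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()


if __name__ == "__main__":
    launch("/tmp/OTX-task-shared")  # illustrative path only

Running this prints the same path from every worker, which is the behavior the commit depends on for multi-GPU training.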
