align with pre-commit test

eunwoosh committed Dec 15, 2022
1 parent 2b8741c commit ef4eadd

Showing 5 changed files with 89 additions and 22 deletions.
2 changes: 1 addition & 1 deletion otx/algorithms/action/tasks/inference.py
@@ -74,7 +74,7 @@ def __init__(self, task_environment: TaskEnvironment, **kwargs):
         # self._should_stop = False
         self._model = None
         self.task_environment = task_environment
-        super().__init__(ActionConfig, task_environment **kwargs)
+        super().__init__(ActionConfig, task_environment, **kwargs)
 
     @check_input_parameters_type({"dataset": DatasetParamTypeCheck})
     def infer(
1 change: 1 addition & 0 deletions otx/algorithms/anomaly/tasks/inference.py
@@ -75,6 +75,7 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str]
         Args:
             task_environment (TaskEnvironment): OTX Task environment.
             output_path (Optional[str]): output path where task output are saved.
+        """
         torch.backends.cudnn.enabled = True
         logger.info("Initializing the task environment.")
2 changes: 1 addition & 1 deletion otx/algorithms/common/adapters/mmcv/runner.py
@@ -18,8 +18,8 @@
 # * https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/epoch_based_runner.py
 # * https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
 
-import warnings
 import time
+import warnings
 from typing import List, Optional, Sequence
 
 import mmcv
102 changes: 84 additions & 18 deletions otx/cli/tools/train.py
@@ -30,6 +30,7 @@
 from otx.api.configuration.helper import create
 from otx.api.entities.inference_parameters import InferenceParameters
 from otx.api.entities.model import ModelEntity
+from otx.api.entities.model_template import ModelTemplate, TaskType
 from otx.api.entities.resultset import ResultSetEntity
 from otx.api.entities.subset import Subset
 from otx.api.entities.task_environment import TaskEnvironment
@@ -140,7 +141,8 @@ def parse_args():
     parser.add_argument(
         "--gpus",
         type=str,
-        help="Comma-separated indcies of GPU. If there are more than one available GPU, then model is trained with multi GPUs.",
+        help="Comma-separated indcies of GPU. \
+            If there are more than one available GPU, then model is trained with multi GPUs.",
     )
     parser.add_argument(
         "--multi-gpu-port",
@@ -251,13 +253,27 @@ def main():
 
 
 class MultiGPUManager:
+    """Class to manage multi-GPU training."""
 
     def __init__(self, gpu_ids: str, multi_gpu_port: str):
-        self._gpu_ids = self._get_gpu_ids(gpu_ids)
+        self._gpu_ids = self.get_gpu_ids(gpu_ids)
         self._multi_gpu_port = multi_gpu_port
         self._main_pid = os.getpid()
-        self._processes = None
+        self._processes: Optional[List[mp.Process]] = None

+    @staticmethod
+    def get_gpu_ids(gpus: str) -> List[int]:
+        """Get proper GPU indices form `--gpu` arguments.
+        Given `--gpus` argument, exclude inappropriate indices and transform to list of int format.
+        Args:
+            gpus (str): GPU indices to use. Format should be Comma-separated indices.
-    def _get_gpu_ids(self, gpus: str) -> List[int]:
+        Returns:
+            List[int]:
+                list including proper GPU indices.
+        """
         num_available_gpu = torch.cuda.device_count()
         gpu_ids = []
         for gpu_id in gpus.split(","):
@@ -278,13 +294,32 @@ def _get_gpu_ids(self, gpus: str) -> List[int]:
 
         return gpu_ids
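
A quick illustration, not part of the diff: the now-public helper can be called directly. The filtering behavior is an assumption based on the docstring, since the loop body is collapsed above.

    # Hypothetical use of the renamed helper (a sketch, not code from this commit).
    from otx.cli.tools.train import MultiGPUManager

    print(MultiGPUManager.get_gpu_ids("0,1,7"))  # with 2 GPUs visible, expect [0, 1]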

-    def is_available(self, template) -> bool:
-        return len(self._gpu_ids) > 1 and not template.task_type.is_anomaly
+    def is_available(self, template: ModelTemplate) -> bool:
+        """Check multi GPU training is available.
+        Args:
+            template (ModelTemplate): template for training.
+        Returns:
+            bool:
+                whether multi GPU training is available.
+        """
+        return (
+            len(self._gpu_ids) > 1
+            and not template.task_type.is_anomaly
+            and template.task_type not in (TaskType.ACTION_CLASSIFICATION, TaskType.ACTION_DETECTION)
+        )

     def setup_multi_gpu_train(
         self, output_path: str, optimized_hyper_parameters: Optional[ConfigurableParameters] = None
     ):
-        if optimized_hyper_parameters is not None:
+        """Carry out what should be done to run multi GPU training.
+        Args:
+            output_path (str): output path where task output are saved.
+            optimized_hyper_parameters (ConfigurableParameters or None): hyper parameters reflecting HPO result.
+        """
+        if optimized_hyper_parameters is not None:  # if HPO is executed, optimized HPs are applied to child processes
             self._set_optimized_hp_for_child_process(optimized_hyper_parameters)
 
         self._processes = self._spawn_multi_gpu_processes(output_path)
@@ -294,16 +329,23 @@ def setup_multi_gpu_train(
 
         self.initialize_multigpu_train(0, self._gpu_ids, self._multi_gpu_port)
 
-        t = threading.Thread(target=self._check_child_processes_alive, daemon=True)
-        t.start()
+        threading.Thread(target=self._check_child_processes_alive, daemon=True).start()

     def finalize(self):
+        """Join all child processes."""
         if self._processes is not None:
             for p in self._processes:
                 p.join()
     @staticmethod
     def initialize_multigpu_train(rank: int, gpu_ids: List[int], multi_gpu_port: str):
+        """Initilization for multi GPU training.
+        Args:
+            rank (int): index of multi GPU processes.
+            gpu_ids (List[int]): list including which GPU indeces will be used.
+            multi_gpu_port (str): port for communication between multi GPU processes.
+        """
         os.environ["MASTER_ADDR"] = "localhost"
         os.environ["MASTER_PORT"] = multi_gpu_port
         torch.cuda.set_device(gpu_ids[rank])
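
The rest of the initializer is collapsed above. With MASTER_ADDR and MASTER_PORT exported and the device pinned, the conventional next step for PyTorch distributed training would be a process-group init; this is an assumption, not code visible in the hunk.

    # Assumed continuation of initialize_multigpu_train (not shown in the diff).
    import torch.distributed as dist

    dist.init_process_group(backend="nccl", world_size=len(gpu_ids), rank=rank)
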
@@ -312,6 +354,14 @@ def initialize_multigpu_train(rank: int, gpu_ids: List[int], multi_gpu_port: str
 
     @staticmethod
     def run_child_process(rank: int, gpu_ids: List[int], output_path: str, multi_gpu_port: str):
+        """Function for multi GPU child process to execute.
+        Args:
+            rank (int): index of multi GPU processes.
+            gpu_ids (List[int]): list including which GPU indeces will be used.
+            output_path (str): output path where task output are saved.
+            multi_gpu_port (str): port for communication between multi GPU processes.
+        """
         gpus_arg_idx = sys.argv.index("--gpus")
         for _ in range(2):
             sys.argv.pop(gpus_arg_idx)
@@ -324,17 +374,29 @@ def run_child_process(rank: int, gpu_ids: List[int], output_path: str, multi_gpu
         main()
 
     @staticmethod
-    def set_arguments_to_argv(key: str, value: str, after_params: bool = False):
+    def set_arguments_to_argv(key: str, value: Optional[str] = None, after_params: bool = False):
+        """Add arguments at proper position in `sys.argv`.
+        Args:
+            key (str): arguement key.
+            value (str or None): argument value.
+            after_params (bool): whether argument should be after `param` or not.
+        """
         if key in sys.argv:
-            sys.argv[sys.argv.index(key) + 1] = value
+            if value is not None:
+                sys.argv[sys.argv.index(key) + 1] = value
         else:
             if not after_params and "params" in sys.argv:
                 sys.argv.insert(sys.argv.index("params"), key)
-                sys.argv.insert(sys.argv.index("params"), value)
+                if value is not None:
+                    sys.argv.insert(sys.argv.index("params"), value)
             else:
                 if after_params and "params" not in sys.argv:
                     sys.argv.append("params")
-                sys.argv.extend([key, value])
+                if value is not None:
+                    sys.argv.extend([key, value])
+                else:
+                    sys.argv.append(key)
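
Making value optional lets the helper insert bare flags as well as key/value pairs. A sketch of the expected sys.argv rewrites, using hypothetical arguments:

    # Hypothetical before/after for the updated set_arguments_to_argv.
    import sys

    sys.argv = ["otx", "train", "template.yaml", "params"]

    # Key absent, after_params=True, "params" present: appended at the end.
    MultiGPUManager.set_arguments_to_argv("--learning_parameters.num_iters", "10", after_params=True)
    # sys.argv == ["otx", "train", "template.yaml", "params",
    #              "--learning_parameters.num_iters", "10"]

    # value=None now appends only the key instead of inserting a None.
    MultiGPUManager.set_arguments_to_argv("--some-flag", after_params=True)
    # sys.argv == [..., "--some-flag"]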

     def _spawn_multi_gpu_processes(self, output_path: str) -> List[mp.Process]:
         processes = []
@@ -348,8 +410,8 @@ def _spawn_multi_gpu_processes(self, output_path: str) -> List[mp.Process]:
 
         return processes
 
-    def _terminate_signal_handler(self, signum, frame):
-        # This code prevents child processses from being killed unintentionally by forked main process
+    def _terminate_signal_handler(self, signum, _frame):
+        # This code prevents child processses from being killed unintentionally by proccesses forked from main process
         if self._main_pid != os.getpid():
             sys.exit()

@@ -368,15 +430,19 @@ def _kill_child_process(self):
                 print(f"Kill child process {process.pid}")
                 try:
                     process.kill()
-                except Exception:
+                except Exception:  # pylint: disable=broad-except
                     pass
 
     def _set_optimized_hp_for_child_process(self, hyper_parameters: ConfigurableParameters):
         self.set_arguments_to_argv(
-            "--learning_parameters.learning_rate", str(hyper_parameters.learning_parameters.learning_rate), True
+            "--learning_parameters.learning_rate",
+            str(hyper_parameters.learning_parameters.learning_rate),  # type: ignore[attr-defined]
+            True,
         )
         self.set_arguments_to_argv(
-            "--learning_parameters.batch_size", str(hyper_parameters.learning_parameters.batch_size), True
+            "--learning_parameters.batch_size",
+            str(hyper_parameters.learning_parameters.batch_size),  # type: ignore[attr-defined]
+            True,
         )
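
The net effect: when HPO has run, the parent rewrites its own sys.argv before spawning, so each child re-enters main() with the tuned values placed after params. Schematically, with hypothetical values:

    # Sketch of a child-process argv after _set_optimized_hp_for_child_process:
    # [..., "params",
    #  "--learning_parameters.learning_rate", "0.007",
    #  "--learning_parameters.batch_size", "32"]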

     def _check_child_processes_alive(self):
4 changes: 2 additions & 2 deletions otx/cli/utils/hpo.py
@@ -335,8 +335,8 @@ def get_train_wrapper_task(impl_class, task_type):
     class HpoTrainTask(impl_class):
         """wrapper class for the HPO."""
 
-        def __init__(self, task_environment):
-            super().__init__(task_environment)
+        def __init__(self, task_environment, **kwargs):
+            super().__init__(task_environment, **kwargs)
             self._task_type = task_type
 
         # TODO: need to check things below whether works on MPA tasks
