Skip to content

Commit

Permalink
Merge pull request #1392 from openvinotoolkit/multigpu_enablement
Browse files Browse the repository at this point in the history
[OTX] Enable multi-GPU training
  • Loading branch information
eunwoosh authored Dec 28, 2022
2 parents 0671efe + a4fa37c commit a31c064
Show file tree
Hide file tree
Showing 20 changed files with 372 additions and 41 deletions.
4 changes: 2 additions & 2 deletions otx/algorithms/action/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ class ActionInferenceTask(BaseTask, IInferenceTask, IExportTask, IEvaluationTask
"""Inference Task Implementation of OTX Action Task."""

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
def __init__(self, task_environment: TaskEnvironment, **kwargs):
# self._should_stop = False
self._model = None
self.task_environment = task_environment
super().__init__(ActionConfig, task_environment)
super().__init__(ActionConfig, task_environment, **kwargs)

@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def infer(
Expand Down
5 changes: 3 additions & 2 deletions otx/algorithms/anomaly/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,12 @@
class InferenceTask(IInferenceTask, IEvaluationTask, IExportTask, IUnload):
"""Base Anomaly Task."""

def __init__(self, task_environment: TaskEnvironment) -> None:
def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] = None) -> None:
"""Train, Infer, Export, Optimize and Deploy an Anomaly Classification Task.
Args:
task_environment (TaskEnvironment): OTX Task environment.
output_path (Optional[str]): output path where task outputs are saved.
"""
torch.backends.cudnn.enabled = True
logger.info("Initializing the task environment.")
Expand All @@ -87,7 +88,7 @@ def __init__(self, task_environment: TaskEnvironment) -> None:
self.base_dir = os.path.abspath(os.path.dirname(template_file_path))

# Hyperparameters.
self.project_path: str = tempfile.mkdtemp(prefix="otx-anomalib")
self.project_path: str = output_path if output_path is not None else tempfile.mkdtemp(prefix="otx-anomalib")
self.config = self.get_config()

# Set default model attributes.
Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/anomaly/tasks/nncf.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@
class NNCFTask(InferenceTask, IOptimizationTask):
"""Anomaly task for compressing models using NNCF."""

def __init__(self, task_environment: TaskEnvironment) -> None:
def __init__(self, task_environment: TaskEnvironment, **kwargs) -> None:
"""Task for compressing models using NNCF.
Args:
task_environment (TaskEnvironment): OTX Task environment.
"""
self.compression_ctrl = None
self.nncf_preset = "nncf_quantization"
super().__init__(task_environment)
super().__init__(task_environment, **kwargs)
self.optimization_type = ModelOptimizationType.NNCF

def _set_attributes_by_hyperparams(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,11 @@ class ClassificationInferenceTask(
task_environment: TaskEnvironment

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] = None):
logger.info("Loading ClassificationTask.")
self._scratch_space = tempfile.mkdtemp(prefix="otx-cls-scratch-")
if output_path is None:
output_path = tempfile.mkdtemp(prefix="otx-cls-scratch-")
self._scratch_space = output_path
logger.info(f"Scratch space created at {self._scratch_space}")

self._task_environment = task_environment
Expand Down Expand Up @@ -457,7 +459,7 @@ class ClassificationNNCFTask(
): # pylint: disable=too-many-instance-attributes
"""Task for compressing classification models using NNCF."""

def __init__(self, task_environment: TaskEnvironment):
def __init__(self, task_environment: TaskEnvironment, **kwargs):
curr_model_path = task_environment.model_template.model_template_path
base_model_path = os.path.join(
os.path.dirname(os.path.abspath(curr_model_path)),
Expand All @@ -468,7 +470,7 @@ def __init__(self, task_environment: TaskEnvironment):
# Redirect to base model
task_environment.model_template = parse_model_template(base_model_path)
logger.info("Loading ClassificationNNCFTask.")
super().__init__(task_environment)
super().__init__(task_environment, **kwargs)

check_nncf_is_enabled()

Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/classification/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ class ClassificationInferenceTask(
): # pylint: disable=too-many-instance-attributes
"""Inference Task Implementation of OTX Classification."""

def __init__(self, task_environment: TaskEnvironment):
def __init__(self, task_environment: TaskEnvironment, **kwargs):
self._should_stop = False
super().__init__(TASK_CONFIG, task_environment)
super().__init__(TASK_CONFIG, task_environment, **kwargs)

self._task_environment = task_environment
if len(task_environment.get_labels(False)) == 1:
Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/classification/tasks/nncf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@
class OTXClassificationNNCFTask(ClassificationNNCFTask):
"""Task for compressing classification models using NNCF."""

def __init__(self, task_environment: TaskEnvironment): # pylint: disable=useless-parent-delegation
super().__init__(task_environment)
def __init__(self, task_environment: TaskEnvironment, **kwargs): # pylint: disable=useless-parent-delegation
super().__init__(task_environment, **kwargs)
5 changes: 3 additions & 2 deletions otx/algorithms/common/adapters/mmcv/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# * https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/epoch_based_runner.py
# * https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py

import time
import warnings
from typing import List, Optional, Sequence

Expand Down Expand Up @@ -75,8 +76,8 @@ def train(self, data_loader: DataLoader, **kwargs):
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
self.call_hook("before_train_epoch")
# TODO: uncomment below line or resolve root cause of deadlock issue if multi-GPUs need to be supported.
# time.sleep(2) # Prevent possible multi-gpu deadlock during epoch transition
if self.distributed:
time.sleep(2) # Prevent possible multi-gpu deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook("before_train_iter")
Expand Down
8 changes: 5 additions & 3 deletions otx/algorithms/common/tasks/training_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,24 @@ class BaseTask(IInferenceTask, IExportTask, IEvaluationTask, IUnload):
_task_environment: TaskEnvironment

@check_input_parameters_type()
def __init__(self, task_config, task_environment: TaskEnvironment):
def __init__(self, task_config, task_environment: TaskEnvironment, output_path: Optional[str] = None):
self._task_config = task_config
self._task_environment = task_environment
self._hyperparams = task_environment.get_hyper_parameters(self._task_config) # type: ConfigDict
self._model_name = task_environment.model_template.name
self._task_type = task_environment.model_template.task_type
self._labels = task_environment.get_labels(include_empty=False)
self._output_path = tempfile.mkdtemp(prefix="OTX-task-")
logger.info(f"created output path at {self._output_path}")
self.confidence_threshold = self._get_confidence_threshold(self._hyperparams)
# Set default model attributes.
self._model_label_schema = [] # type: List[LabelEntity]
self._optimization_methods = [] # type: List[OptimizationMethod]
self._model_ckpt = None
self._data_pipeline_path = None
self._anchors = {} # type: Dict[str, int]
if output_path is None:
output_path = tempfile.mkdtemp(prefix="OTX-task-")
self._output_path = output_path
logger.info(f"created output path at {self._output_path}")
if task_environment.model is not None:
logger.info("loading the model from the task env.")
state_dict = self._load_model_state_dict(self._task_environment.model)
Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/detection/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ class DetectionInferenceTask(BaseTask, IInferenceTask, IExportTask, IEvaluationT
"""Inference Task Implementation of OTX Detection."""

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
def __init__(self, task_environment: TaskEnvironment, **kwargs):
# self._should_stop = False
self.train_type = None
super().__init__(DetectionConfig, task_environment)
super().__init__(DetectionConfig, task_environment, **kwargs)
self.template_dir = os.path.abspath(os.path.dirname(self.template_file_path))
self.base_dir = self.template_dir
# TODO Move this to the common
Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/detection/tasks/nncf.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ class DetectionNNCFTask(DetectionInferenceTask, IOptimizationTask):
"""Task for compressing detection models using NNCF."""

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
super().__init__(task_environment)
def __init__(self, task_environment: TaskEnvironment, **kwargs):
super().__init__(task_environment, **kwargs)
self._val_dataloader = None
self._compression_ctrl = None
self._nncf_preset = "nncf_quantization"
Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/segmentation/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ class SegmentationInferenceTask(BaseTask, IInferenceTask, IExportTask, IEvaluati
"""Inference Task Implementation of OTX Segmentation."""

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
def __init__(self, task_environment: TaskEnvironment, **kwargs):
# self._should_stop = False
self.freeze = True
self.metric = "mDice"
self._label_dictionary = {} # type: Dict
super().__init__(SegmentationConfig, task_environment)
super().__init__(SegmentationConfig, task_environment, **kwargs)

def infer(
self, dataset: DatasetEntity, inference_parameters: Optional[InferenceParameters] = None
Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/segmentation/tasks/nncf.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ class SegmentationNNCFTask(SegmentationInferenceTask, IOptimizationTask):
"""Task for compressing segmentation models using NNCF."""

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
super().__init__(task_environment)
def __init__(self, task_environment: TaskEnvironment, **kwargs):
super().__init__(task_environment, **kwargs)

self._val_dataloader = None
self._compression_ctrl = None
Expand Down
41 changes: 30 additions & 11 deletions otx/cli/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

import argparse
import os
import shutil

from otx.api.configuration.helper import create
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.model import ModelEntity
from otx.api.entities.model_template import TaskType
from otx.api.entities.resultset import ResultSetEntity
from otx.api.entities.subset import Subset
from otx.api.entities.task_environment import TaskEnvironment
Expand All @@ -38,6 +38,7 @@
read_label_schema,
save_model_data,
)
from otx.cli.utils.multi_gpu import MultiGPUManager
from otx.cli.utils.parser import (
add_hyper_parameters_sub_parser,
gen_params_dict_from_args,
Expand Down Expand Up @@ -113,9 +114,9 @@ def parse_args():
help="Location where trained model will be stored.",
)
parser.add_argument(
"--save-logs-to",
"--work-dir",
required=False,
help="Location where logs will be stored.",
help="Location where the intermediate output of the training will be stored.",
)
parser.add_argument(
"--enable-hpo",
Expand All @@ -128,6 +129,18 @@ def parse_args():
type=float,
help="Expected ratio of total time to run HPO to time taken for full fine-tuning.",
)
parser.add_argument(
"--gpus",
type=str,
help="Comma-separated indices of GPU. \
If there are more than one available GPU, then model is trained with multi GPUs.",
)
parser.add_argument(
"--multi-gpu-port",
default=25000,
type=int,
help="port for communication beteween multi GPU processes.",
)

add_hyper_parameters_sub_parser(parser, hyper_parameters)

Expand All @@ -136,7 +149,6 @@ def parse_args():

def main():
"""Main function that is used for model training."""

# Dynamically create an argument parser based on override parameters.
args, template, hyper_parameters = parse_args()
# Get new values from user's input.
Expand Down Expand Up @@ -195,9 +207,19 @@ def main():
task = run_hpo(args, environment, dataset, template.task_type)
if task is None:
print("cannot run HPO for this task. will train a model without HPO.")
task = task_class(task_environment=environment)
task = task_class(task_environment=environment, output_path=args.work_dir)
else:
task = task_class(task_environment=environment)
task = task_class(task_environment=environment, output_path=args.work_dir)

if args.gpus:
multigpu_manager = MultiGPUManager(main, args.gpus, str(args.multi_gpu_port))
if template.task_type in (TaskType.ACTION_CLASSIFICATION, TaskType.ACTION_DETECTION):
print("Multi-GPU training for action tasks isn't supported yet. A single GPU will be used for a training.")
elif (
multigpu_manager.is_available()
and not template.task_type.is_anomaly # anomaly tasks do not use this mechanism for multi-GPU training
):
multigpu_manager.setup_multi_gpu_train(task.project_path, hyper_parameters if args.enable_hpo else None)

output_model = ModelEntity(dataset, environment.get_model_configuration())

Expand All @@ -223,11 +245,8 @@ def main():
assert resultset.performance is not None
print(resultset.performance)

if args.save_logs_to:
tmp_path = task.project_path
logs_path = os.path.join(args.save_logs_to, tmp_path.split("/")[-1])
shutil.copytree(tmp_path, logs_path)
print(f"Save logs: {logs_path}")
if args.gpus:
multigpu_manager.finalize()


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions otx/cli/utils/hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def run_hpo(args: Namespace, environment: TaskEnvironment, dataset: DatasetEntit
task_class = get_impl_class(environment.model_template.entrypoints.base)
task_class = get_train_wrapper_task(task_class, task_type)

task = task_class(task_environment=environment)
task = task_class(task_environment=environment, output_path=args.work_dir)

hpopt_cfg = _load_hpopt_config(
osp.join(
Expand Down Expand Up @@ -335,8 +335,8 @@ def get_train_wrapper_task(impl_class, task_type):
class HpoTrainTask(impl_class):
"""wrapper class for the HPO."""

def __init__(self, task_environment):
super().__init__(task_environment)
def __init__(self, task_environment, **kwargs):
super().__init__(task_environment, **kwargs)
self._task_type = task_type

# TODO: need to check things below whether works on MPA tasks
Expand Down
Loading

0 comments on commit a31c064

Please sign in to comment.