Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OTX] Enable multi-GPU training #1392

Merged
merged 23 commits into from
Dec 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions otx/algorithms/action/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ class ActionInferenceTask(BaseTask, IInferenceTask, IExportTask, IEvaluationTask
"""Inference Task Implementation of OTX Action Task."""

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
def __init__(self, task_environment: TaskEnvironment, **kwargs):
# self._should_stop = False
self._model = None
self.task_environment = task_environment
super().__init__(ActionConfig, task_environment)
super().__init__(ActionConfig, task_environment, **kwargs)

@check_input_parameters_type({"dataset": DatasetParamTypeCheck})
def infer(
Expand Down
5 changes: 3 additions & 2 deletions otx/algorithms/anomaly/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,12 @@
class InferenceTask(IInferenceTask, IEvaluationTask, IExportTask, IUnload):
"""Base Anomaly Task."""

def __init__(self, task_environment: TaskEnvironment) -> None:
def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] = None) -> None:
harimkang marked this conversation as resolved.
Show resolved Hide resolved
"""Train, Infer, Export, Optimize and Deploy an Anomaly Classification Task.

Args:
task_environment (TaskEnvironment): OTX Task environment.
            output_path (Optional[str]): output path where task outputs are saved.
"""
torch.backends.cudnn.enabled = True
logger.info("Initializing the task environment.")
Expand All @@ -87,7 +88,7 @@ def __init__(self, task_environment: TaskEnvironment) -> None:
self.base_dir = os.path.abspath(os.path.dirname(template_file_path))

# Hyperparameters.
self.project_path: str = tempfile.mkdtemp(prefix="otx-anomalib")
self.project_path: str = output_path if output_path is not None else tempfile.mkdtemp(prefix="otx-anomalib")
self.config = self.get_config()

# Set default model attributes.
Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/anomaly/tasks/nncf.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@
class NNCFTask(InferenceTask, IOptimizationTask):
"""Base Anomaly Task."""

def __init__(self, task_environment: TaskEnvironment) -> None:
def __init__(self, task_environment: TaskEnvironment, **kwargs) -> None:
"""Task for compressing models using NNCF.

Args:
task_environment (TaskEnvironment): OTX Task environment.
"""
self.compression_ctrl = None
self.nncf_preset = "nncf_quantization"
super().__init__(task_environment)
super().__init__(task_environment, **kwargs)
self.optimization_type = ModelOptimizationType.NNCF

def _set_attributes_by_hyperparams(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,11 @@ class ClassificationInferenceTask(
task_environment: TaskEnvironment

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str] = None):
logger.info("Loading ClassificationTask.")
self._scratch_space = tempfile.mkdtemp(prefix="otx-cls-scratch-")
if output_path is None:
output_path = tempfile.mkdtemp(prefix="otx-cls-scratch-")
self._scratch_space = output_path
logger.info(f"Scratch space created at {self._scratch_space}")

self._task_environment = task_environment
Expand Down Expand Up @@ -457,7 +459,7 @@ class ClassificationNNCFTask(
): # pylint: disable=too-many-instance-attributes
"""Task for compressing classification models using NNCF."""

def __init__(self, task_environment: TaskEnvironment):
def __init__(self, task_environment: TaskEnvironment, **kwargs):
curr_model_path = task_environment.model_template.model_template_path
base_model_path = os.path.join(
os.path.dirname(os.path.abspath(curr_model_path)),
Expand All @@ -468,7 +470,7 @@ def __init__(self, task_environment: TaskEnvironment):
# Redirect to base model
task_environment.model_template = parse_model_template(base_model_path)
logger.info("Loading ClassificationNNCFTask.")
super().__init__(task_environment)
super().__init__(task_environment, **kwargs)

check_nncf_is_enabled()

Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/classification/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ class ClassificationInferenceTask(
): # pylint: disable=too-many-instance-attributes
"""Inference Task Implementation of OTX Classification."""

def __init__(self, task_environment: TaskEnvironment):
def __init__(self, task_environment: TaskEnvironment, **kwargs):
self._should_stop = False
super().__init__(TASK_CONFIG, task_environment)
super().__init__(TASK_CONFIG, task_environment, **kwargs)

self._task_environment = task_environment
if len(task_environment.get_labels(False)) == 1:
Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/classification/tasks/nncf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@
class OTXClassificationNNCFTask(ClassificationNNCFTask):
"""Task for compressing classification models using NNCF."""

def __init__(self, task_environment: TaskEnvironment): # pylint: disable=useless-parent-delegation
super().__init__(task_environment)
def __init__(self, task_environment: TaskEnvironment, **kwargs): # pylint: disable=useless-parent-delegation
super().__init__(task_environment, **kwargs)
5 changes: 3 additions & 2 deletions otx/algorithms/common/adapters/mmcv/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# * https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/epoch_based_runner.py
# * https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py

import time
import warnings
from typing import List, Optional, Sequence

Expand Down Expand Up @@ -75,8 +76,8 @@ def train(self, data_loader: DataLoader, **kwargs):
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
self.call_hook("before_train_epoch")
# TODO: uncomment below line or resolve root cause of deadlock issue if multi-GPUs need to be supported.
# time.sleep(2) # Prevent possible multi-gpu deadlock during epoch transition
if self.distributed:
time.sleep(2) # Prevent possible multi-gpu deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook("before_train_iter")
Expand Down
8 changes: 5 additions & 3 deletions otx/algorithms/common/tasks/training_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,24 @@ class BaseTask(IInferenceTask, IExportTask, IEvaluationTask, IUnload):
_task_environment: TaskEnvironment

@check_input_parameters_type()
def __init__(self, task_config, task_environment: TaskEnvironment):
def __init__(self, task_config, task_environment: TaskEnvironment, output_path: Optional[str] = None):
self._task_config = task_config
self._task_environment = task_environment
self._hyperparams = task_environment.get_hyper_parameters(self._task_config) # type: ConfigDict
self._model_name = task_environment.model_template.name
self._task_type = task_environment.model_template.task_type
self._labels = task_environment.get_labels(include_empty=False)
self._output_path = tempfile.mkdtemp(prefix="OTX-task-")
logger.info(f"created output path at {self._output_path}")
self.confidence_threshold = self._get_confidence_threshold(self._hyperparams)
# Set default model attributes.
self._model_label_schema = [] # type: List[LabelEntity]
self._optimization_methods = [] # type: List[OptimizationMethod]
self._model_ckpt = None
self._data_pipeline_path = None
self._anchors = {} # type: Dict[str, int]
if output_path is None:
jaegukhyun marked this conversation as resolved.
Show resolved Hide resolved
output_path = tempfile.mkdtemp(prefix="OTX-task-")
self._output_path = output_path
logger.info(f"created output path at {self._output_path}")
if task_environment.model is not None:
logger.info("loading the model from the task env.")
state_dict = self._load_model_state_dict(self._task_environment.model)
Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/detection/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ class DetectionInferenceTask(BaseTask, IInferenceTask, IExportTask, IEvaluationT
"""Inference Task Implementation of OTX Detection."""

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
def __init__(self, task_environment: TaskEnvironment, **kwargs):
# self._should_stop = False
self.train_type = None
super().__init__(DetectionConfig, task_environment)
super().__init__(DetectionConfig, task_environment, **kwargs)
self.template_dir = os.path.abspath(os.path.dirname(self.template_file_path))
self.base_dir = self.template_dir
# TODO Move this to the common
Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/detection/tasks/nncf.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ class DetectionNNCFTask(DetectionInferenceTask, IOptimizationTask):
"""Task for compressing detection models using NNCF."""

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
super().__init__(task_environment)
def __init__(self, task_environment: TaskEnvironment, **kwargs):
super().__init__(task_environment, **kwargs)
self._val_dataloader = None
self._compression_ctrl = None
self._nncf_preset = "nncf_quantization"
Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/segmentation/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ class SegmentationInferenceTask(BaseTask, IInferenceTask, IExportTask, IEvaluati
"""Inference Task Implementation of OTX Segmentation."""

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
def __init__(self, task_environment: TaskEnvironment, **kwargs):
# self._should_stop = False
self.freeze = True
self.metric = "mDice"
self._label_dictionary = {} # type: Dict
super().__init__(SegmentationConfig, task_environment)
super().__init__(SegmentationConfig, task_environment, **kwargs)

def infer(
self, dataset: DatasetEntity, inference_parameters: Optional[InferenceParameters] = None
Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/segmentation/tasks/nncf.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ class SegmentationNNCFTask(SegmentationInferenceTask, IOptimizationTask):
"""Task for compressing object detection models using NNCF."""

@check_input_parameters_type()
def __init__(self, task_environment: TaskEnvironment):
super().__init__(task_environment)
def __init__(self, task_environment: TaskEnvironment, **kwargs):
super().__init__(task_environment, **kwargs)

self._val_dataloader = None
self._compression_ctrl = None
Expand Down
41 changes: 30 additions & 11 deletions otx/cli/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

import argparse
import os
import shutil

from otx.api.configuration.helper import create
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.model import ModelEntity
from otx.api.entities.model_template import TaskType
from otx.api.entities.resultset import ResultSetEntity
from otx.api.entities.subset import Subset
from otx.api.entities.task_environment import TaskEnvironment
Expand All @@ -38,6 +38,7 @@
read_label_schema,
save_model_data,
)
from otx.cli.utils.multi_gpu import MultiGPUManager
from otx.cli.utils.parser import (
add_hyper_parameters_sub_parser,
gen_params_dict_from_args,
Expand Down Expand Up @@ -113,9 +114,9 @@ def parse_args():
help="Location where trained model will be stored.",
)
parser.add_argument(
"--save-logs-to",
"--work-dir",
required=False,
help="Location where logs will be stored.",
help="Location where the intermediate output of the training will be stored.",
)
parser.add_argument(
"--enable-hpo",
Expand All @@ -128,6 +129,18 @@ def parse_args():
type=float,
help="Expected ratio of total time to run HPO to time taken for full fine-tuning.",
)
parser.add_argument(
jaegukhyun marked this conversation as resolved.
Show resolved Hide resolved
"--gpus",
type=str,
help="Comma-separated indices of GPU. \
If there are more than one available GPU, then model is trained with multi GPUs.",
JihwanEom marked this conversation as resolved.
Show resolved Hide resolved
)
parser.add_argument(
"--multi-gpu-port",
harimkang marked this conversation as resolved.
Show resolved Hide resolved
default=25000,
type=int,
help="port for communication beteween multi GPU processes.",
JihwanEom marked this conversation as resolved.
Show resolved Hide resolved
)

add_hyper_parameters_sub_parser(parser, hyper_parameters)

Expand All @@ -136,7 +149,6 @@ def parse_args():

def main():
"""Main function that is used for model training."""

# Dynamically create an argument parser based on override parameters.
args, template, hyper_parameters = parse_args()
# Get new values from user's input.
Expand Down Expand Up @@ -195,9 +207,19 @@ def main():
task = run_hpo(args, environment, dataset, template.task_type)
if task is None:
print("cannot run HPO for this task. will train a model without HPO.")
task = task_class(task_environment=environment)
task = task_class(task_environment=environment, output_path=args.work_dir)
else:
task = task_class(task_environment=environment)
task = task_class(task_environment=environment, output_path=args.work_dir)

if args.gpus:
multigpu_manager = MultiGPUManager(main, args.gpus, str(args.multi_gpu_port))
if template.task_type in (TaskType.ACTION_CLASSIFICATION, TaskType.ACTION_DETECTION):
print("Multi-GPU training for action tasks isn't supported yet. A single GPU will be used for a training.")
elif (
multigpu_manager.is_available()
and not template.task_type.is_anomaly # anomaly tasks don't use this way for multi-GPU training
):
multigpu_manager.setup_multi_gpu_train(task.project_path, hyper_parameters if args.enable_hpo else None)

output_model = ModelEntity(dataset, environment.get_model_configuration())

Expand All @@ -223,11 +245,8 @@ def main():
assert resultset.performance is not None
print(resultset.performance)

if args.save_logs_to:
tmp_path = task.project_path
logs_path = os.path.join(args.save_logs_to, tmp_path.split("/")[-1])
shutil.copytree(tmp_path, logs_path)
print(f"Save logs: {logs_path}")
if args.gpus:
sungmanc marked this conversation as resolved.
Show resolved Hide resolved
multigpu_manager.finalize()


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions otx/cli/utils/hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def run_hpo(args: Namespace, environment: TaskEnvironment, dataset: DatasetEntit
task_class = get_impl_class(environment.model_template.entrypoints.base)
task_class = get_train_wrapper_task(task_class, task_type)

task = task_class(task_environment=environment)
task = task_class(task_environment=environment, output_path=args.work_dir)

hpopt_cfg = _load_hpopt_config(
osp.join(
Expand Down Expand Up @@ -335,8 +335,8 @@ def get_train_wrapper_task(impl_class, task_type):
class HpoTrainTask(impl_class):
"""wrapper class for the HPO."""

def __init__(self, task_environment):
super().__init__(task_environment)
def __init__(self, task_environment, **kwargs):
super().__init__(task_environment, **kwargs)
self._task_type = task_type

# TODO: need to check things below whether works on MPA tasks
Expand Down
Loading