
[Retiarii] Retry a failed multi-model trial by disabling CGO in CGOExecutionEngine #4098

Merged · 20 commits · Oct 11, 2021
81 changes: 59 additions & 22 deletions nni/retiarii/execution/cgo_engine.py
@@ -1,13 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.


import logging
import os
import random
import string
import time
import threading
from typing import Iterable, List, Dict, Tuple
from dataclasses import dataclass

from nni.common.device import GPUDevice, Device
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
@@ -17,13 +19,20 @@
from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
from ..evaluator.pytorch.lightning import Lightning
from ..evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule, _MultiModelSupervisedLearningModule
from ..evaluator.pytorch.cgo.evaluator import _MultiModelSupervisedLearningModule

from .base import BaseGraphData

_logger = logging.getLogger(__name__)


@dataclass
class TrialSubmission:
model: Model
placement: Dict[Node, Device]
grouped_models: List[Model]


class CGOExecutionEngine(AbstractExecutionEngine):
"""
The execution engine with Cross-Graph Optimization (CGO).
@@ -64,7 +73,8 @@ def __init__(self, devices: List[Device] = None,

self._history: List[Model] = []

self._queuing_jobs: List[Model] = []
self._queuing_models: List[Model] = []
self._models_to_retry: List[Model] = []
self._queue_lock = threading.Lock()

# register advisor callbacks
@@ -76,7 +86,7 @@ def __init__(self, devices: List[Device] = None,
advisor.final_metric_callback = self._final_metric_callback

self._stopped = False
self._consumer_thread = threading.Thread(target=self._consume_queue)
self._consumer_thread = threading.Thread(target=self._consume_models)
self._consumer_thread.start()

def join(self):
@@ -90,27 +100,45 @@ def submit_models(self, *models: List[Model]) -> None:
curr_time = time.time()
_logger.info('%d models are submitted', len(models))
self._queue_lock.acquire()
self._queuing_jobs.extend([(curr_time, _) for _ in models])
self._queuing_models.extend([(curr_time, _) for _ in models])
self._queue_lock.release()

def _submit_retry_models(self, models: List[Model]) -> None:
_logger.info('%d models are retried', len(models))
self._queue_lock.acquire()
self._models_to_retry.extend(models)
self._queue_lock.release()

def _consume_queue(self):
# a thread to monitor self.queuing_jobs to consume them in batch
def _consume_models(self):
# a thread to monitor self._models_to_retry and self._queuing_models to consume them in batch
while not self._stopped:
if len(self._queuing_jobs) > 0:
curr_time = time.time()
if len(self._models_to_retry) > 0:
self._queue_lock.acquire()
# retried models should be scheduled first.
for m in self._models_to_retry:
if len(self.available_devices) > 0:
self._submit_models_in_batch(m) # submit the single model to avoid cross-graph optimization.
self._models_to_retry = self._models_to_retry[1:]
self._queue_lock.release()

if len(self._queuing_models) > 0:
self._queue_lock.acquire()
if (self.max_concurrency and len(self._queuing_jobs) >= self.max_concurrency):
self._submit_models_in_batch(*[_[1] for _ in self._queuing_jobs[:self.max_concurrency]])
self._queuing_jobs = self._queuing_jobs[self.max_concurrency:]
elif len(self.available_devices) <= len(self._queuing_jobs) or \
(curr_time - self._queuing_jobs[0][0] > self._batch_waiting_time):
self._submit_models_in_batch(*[_[1] for _ in self._queuing_jobs])
self._queuing_jobs = []
curr_time = time.time()

num_models_to_submit = len(self.available_devices)
if self.max_concurrency:
num_models_to_submit = min(num_models_to_submit, self.max_concurrency)

if curr_time - self._queuing_models[0][0] > self._batch_waiting_time:
num_models_to_submit = min(num_models_to_submit, len(self._queuing_models))
if num_models_to_submit > 0:
self._submit_models_in_batch(*[_[1] for _ in self._queuing_models[:num_models_to_submit]])
self._queuing_models = self._queuing_models[num_models_to_submit:]
self._queue_lock.release()
time.sleep(1)

def _extract_placement_constaint(self, placement_mapping: Dict[Node, Device]):
unique_gpus = sorted(list(set([ e for e in placement_mapping.values() if isinstance(e, GPUDevice)])))
unique_gpus = sorted(list(set([e for e in placement_mapping.values() if isinstance(e, GPUDevice)])))
placement_constraint = None
if len(unique_gpus) > 0:
placement_constraint = {}
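
Stepping back from the hunks above: the heart of this change is the priority order inside the new `_consume_models` thread. Models queued for retry are drained first, one per free device and submitted individually so cross-graph optimization cannot merge them again; only then are normally queued models flushed as a batch sized by the remaining devices. The following is a condensed, self-contained sketch of that order with illustrative names (`submit_single` and `submit_batch` stand in for `_submit_models_in_batch`; the batch-waiting-time check is omitted); it is a sketch for reading the diff, not the code being merged.

import threading
import time

# Illustrative state mirroring the engine's fields (hypothetical values).
queue_lock = threading.Lock()
available_devices = ["gpu:0", "gpu:1", "gpu:2"]
queuing_models = [(time.time(), f"model-{i}") for i in range(4)]  # (enqueue time, model)
models_to_retry = ["model-a", "model-b"]  # originals from a failed multi-model trial
max_concurrency = 2


def submit_single(model):
    # Stands in for _submit_models_in_batch(model): a single model, so there is nothing for CGO to merge.
    device = available_devices.pop(0)
    print(f"retry {model} alone on {device}")


def submit_batch(models):
    # Stands in for _submit_models_in_batch(*models): these may still be merged by CGO.
    devices = [available_devices.pop(0) for _ in models]
    print(f"submit batch {models} on {devices}")


def consume_once():
    with queue_lock:
        # 1) Retried models are scheduled first, one per free device.
        while models_to_retry and available_devices:
            submit_single(models_to_retry.pop(0))
    with queue_lock:
        # 2) Remaining devices are filled from the normal queue as one batch.
        n = len(available_devices)
        if max_concurrency:
            n = min(n, max_concurrency)
        batch = [m for _, m in queuing_models[:n]]
        del queuing_models[:n]
        if batch:
            submit_batch(batch)


consume_once()
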
@@ -120,6 +148,7 @@ def _extract_placement_constaint(self, placement_mapping: Dict[Node, Device]):

def _submit_models_in_batch(self, *models: List[Model]) -> None:
_logger.info('%d models are submitted in batch', len(models))
_logger.debug('model id: %s', str([m.model_id for m in models]))
logical = self._build_logical(models)

for opt in self._optimizers:
@@ -205,15 +234,23 @@ def _trial_end_callback(self, trial_id: int, success: bool) -> None:
model.status = ModelStatus.Trained
else:
model.status = ModelStatus.Failed
models_to_retry = []
for model_id in self._original_model_to_multi_model:
if self._original_model_to_multi_model[model_id] == model:
original_model = self._original_models[model_id]
if success:
original_model.status = ModelStatus.Trained
else:
original_model.status = ModelStatus.Failed
# the failed models in a multi-model trial will be retried one by one w/o CGO
if len(self._trial_to_original_models[trial_id]) > 1:
models_to_retry.append(original_model)
for listener in self._listeners:
listener.on_training_end(original_model, success)

if len(models_to_retry) > 0:
self._submit_retry_models(models_to_retry)

self.available_devices.extend(self._trial_used_devices[trial_id])
self.available_devices = sorted(list(set(self.available_devices)))
del self._running_models[trial_id]
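
The trial-end side of the change reduces to a small decision rule: an original model becomes a retry candidate only if it failed as part of a trial that had merged more than one model, since rerunning a single failed model without CGO would change nothing. A minimal, hypothetical sketch of that rule (illustrative names, not the merged code):

from typing import Dict, List


def pick_models_to_retry(trial_to_original_models: Dict[int, List[str]],
                         trial_id: int, success: bool) -> List[str]:
    originals = trial_to_original_models[trial_id]
    # Retry only when a failed trial had merged several models: rerunning each one
    # alone removes the cross-graph merge as a possible cause of the failure.
    if not success and len(originals) > 1:
        return list(originals)
    return []


# Example: trial 7 merged three models and failed, so all three are retried individually.
print(pick_models_to_retry({7: ["m1", "m2", "m3"]}, trial_id=7, success=False))
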
Expand Down Expand Up @@ -242,8 +279,11 @@ def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
listener.on_metric(self._original_models[model_id], merged_metrics[model_id])

def query_available_resource(self) -> List[WorkerInfo]:
# the _queuing_jobs need to use available_devices first
return len(self.available_devices) - len(self._queuing_jobs)
# the _queuing_models need to use available_devices first
self._queue_lock.acquire()
available_for_more_models = len(self.available_devices) - len(self._queuing_models) - len(self._models_to_retry)
self._queue_lock.release()
return available_for_more_models

def budget_exhausted(self) -> bool:
advisor = get_advisor()
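
`query_available_resource` now does its accounting under the queue lock and subtracts both pending queues, so the advisor is not told there is room for new models while queued or to-be-retried models are still waiting for devices. A one-function sketch of that accounting (hypothetical names, assumed values):

def free_slots(available_devices, queuing_models, models_to_retry):
    # Devices already earmarked for queued or to-be-retried models are not free.
    return len(available_devices) - len(queuing_models) - len(models_to_retry)


# Three free devices, one queued model, one retry pending -> room for one more submission.
print(free_slots(["gpu:0", "gpu:1", "gpu:2"], [("t0", "m1")], ["m2"]))
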
@@ -269,9 +309,6 @@ def trial_execute_graph(cls) -> None:
os.remove(file_name)





class AssemblePolicy:
@staticmethod
def _is_related_node(model: Model, node: Node):
@@ -299,7 +336,7 @@ def _check_graph_connectivity(model: Model,
@staticmethod
def _check_evaluator(new_model: Model, group_model: Dict[Model, Device]) -> bool:
if not (isinstance(new_model.evaluator, Lightning)
and isinstance(new_model.evaluator.module, MultiModelSupervisedLearningModule)):
and isinstance(new_model.evaluator.module, _MultiModelSupervisedLearningModule)):
return False
for m in group_model:
if not m.evaluator == new_model.evaluator: