Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
[Retiarii] Export topk models (#3464)
Browse files Browse the repository at this point in the history
  • Loading branch information
ultmaster authored Apr 3, 2021
1 parent 0494cae commit aea98dd
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 10 deletions.
9 changes: 8 additions & 1 deletion nni/retiarii/execution/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT license.

import time
from typing import Iterable

from ..graph import Model, ModelStatus
from .interface import AbstractExecutionEngine
Expand All @@ -11,7 +12,7 @@
_default_listener = None

__all__ = ['get_execution_engine', 'get_and_register_default_listener',
'submit_models', 'wait_models', 'query_available_resources',
'list_models', 'submit_models', 'wait_models', 'query_available_resources',
'set_execution_engine', 'is_stopped_exec']

def set_execution_engine(engine) -> None:
Expand Down Expand Up @@ -43,6 +44,12 @@ def submit_models(*models: Model) -> None:
engine.submit_models(*models)


def list_models(*models: Model) -> Iterable[Model]:
    """
    Return the models known to the default execution engine.

    Parameters
    ----------
    *models
        Accepted for signature symmetry with ``submit_models`` but not
        used by this function — the engine's own history is returned.
        NOTE(review): confirm whether filtering by ``*models`` was intended.

    Returns
    -------
    Iterable[Model]
        Whatever ``engine.list_models()`` yields (for the base engine,
        all previously submitted models).
    """
    engine = get_execution_engine()
    # Mirror the setup done by submit_models/wait_models: make sure the
    # default listener is registered with the engine before returning models.
    get_and_register_default_listener(engine)
    return engine.list_models()


def wait_models(*models: Model) -> None:
get_and_register_default_listener(get_execution_engine())
while True:
Expand Down
7 changes: 6 additions & 1 deletion nni/retiarii/execution/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import random
import string
from typing import Dict, List
from typing import Dict, Iterable, List

from .interface import AbstractExecutionEngine, AbstractGraphListener
from .. import codegen, utils
Expand Down Expand Up @@ -53,13 +53,18 @@ def __init__(self) -> None:
advisor.final_metric_callback = self._final_metric_callback

self._running_models: Dict[int, Model] = dict()
self._history: List[Model] = []

self.resources = 0

def submit_models(self, *models: Model) -> None:
    """Generate code for each model, dispatch it as a trial, and record it.

    Every submitted model is remembered in ``_history`` so it can later be
    retrieved via ``list_models``.
    """
    for submitted in models:
        script = codegen.model_to_pytorch_script(submitted)
        graph_data = BaseGraphData(script, submitted.evaluator)
        trial_id = send_trial(graph_data.dump())
        self._running_models[trial_id] = submitted
        self._history.append(submitted)

def list_models(self) -> Iterable[Model]:
    """
    Return every model submitted through this engine, in submission order.

    NOTE(review): the abstract interface's docstring promises "a list of
    copies", but this returns the live ``_history`` list itself — a caller
    mutating the returned list would mutate engine state. Confirm whether
    returning a copy was intended.
    """
    return self._history

def register_graph_listener(self, listener: AbstractGraphListener) -> None:
    """Add *listener* to the listeners this engine notifies of model events."""
    self._listeners.append(listener)
Expand Down
5 changes: 4 additions & 1 deletion nni/retiarii/execution/cgo_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

import logging
from typing import List, Dict, Tuple
from typing import Iterable, List, Dict, Tuple

from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
Expand Down Expand Up @@ -58,6 +58,9 @@ def submit_models(self, *models: List[Model]) -> None:
# model.config['trainer_module'], model.config['trainer_kwargs'])
# self._running_models[send_trial(data.dump())] = model

def list_models(self) -> Iterable[Model]:
    """List submitted models. Not yet implemented for the CGO engine."""
    raise NotImplementedError

def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
# unique_models = set()
# for node in logical_plan.graph.nodes:
Expand Down
11 changes: 10 additions & 1 deletion nni/retiarii/execution/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, NewType, List, Union
from typing import Any, Iterable, NewType, List, Union

from ..graph import Model, MetricData

Expand Down Expand Up @@ -104,6 +104,15 @@ def submit_models(self, *models: Model) -> None:
"""
raise NotImplementedError

@abstractmethod
def list_models(self) -> Iterable[Model]:
    """
    Get all models that have been submitted.

    The execution engine should store a copy of every submitted model and
    return those copies from this method.
    """
    raise NotImplementedError

@abstractmethod
def query_available_resource(self) -> Union[List[WorkerInfo], int]:
"""
Expand Down
27 changes: 22 additions & 5 deletions nni/retiarii/experiment/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from nni.experiment.pipe import Pipe
from nni.tools.nnictl.command_utils import kill_command

from ..codegen import model_to_pytorch_script
from ..converter import convert_to_graph
from ..execution import list_models
from ..graph import Model, Evaluator
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
Expand Down Expand Up @@ -257,16 +259,31 @@ def stop(self) -> None:
self._dispatcher_thread = None
_logger.info('Experiment stopped')

def export_top_models(self, top_n: int = 1):
def export_top_models(self, top_k: int = 1, optimize_mode: str = 'maximize', formatter: str = 'code') -> Any:
"""
export several top performing models
Export several top performing models.
For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` and ``formatter`` are
available for customization.
top_k : int
How many models are intended to be exported.
optimize_mode : str
``maximize`` or ``minimize``. Not supported by one-shot algorithms.
``optimize_mode`` is likely to be removed and defined in strategy in future.
formatter : str
Only model code is supported for now. Not supported by one-shot algorithms.
"""
if top_n != 1:
_logger.warning('Only support top_n is 1 for now.')
if isinstance(self.trainer, BaseOneShotTrainer):
assert top_k == 1, 'Only support top_k is 1 for now.'
return self.trainer.export()
else:
_logger.info('For this experiment, you can find out the best one from WebUI.')
all_models = filter(lambda m: m.metric is not None, list_models())
assert optimize_mode in ['maximize', 'minimize']
all_models = sorted(all_models, key=lambda m: m.metric, reverse=optimize_mode == 'maximize')
assert formatter == 'code', 'Export formatter other than "code" is not supported yet.'
if formatter == 'code':
return [model_to_pytorch_script(model) for model in all_models[:top_k]]

def retrain_model(self, model):
"""
Expand Down
5 changes: 4 additions & 1 deletion test/retiarii_test/mnist/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def forward(self, x):
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'mnist_search'
exp_config.trial_concurrency = 2
exp_config.max_trial_number = 10
exp_config.max_trial_number = 2
exp_config.training_service.use_active_gpu = False

exp.run(exp_config, 8081 + random.randint(0, 100))
print('Final model:')
for model_code in exp.export_top_models():
print(model_code)
3 changes: 3 additions & 0 deletions test/ut/retiarii/test_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def submit_models(self, *models: Model) -> None:
self._resource_left -= 1
threading.Thread(target=self._model_complete, args=(model, )).start()

def list_models(self) -> List[Model]:
    # Return the models this mock engine has recorded.  Annotated as List
    # rather than the interface's Iterable — presumably acceptable in a
    # test double, since List satisfies Iterable.
    return self.models

def query_available_resource(self) -> Union[List[WorkerInfo], int]:
return self._resource_left

Expand Down

0 comments on commit aea98dd

Please sign in to comment.