Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Support multiple experiments of nnictl (#183)
Browse files Browse the repository at this point in the history
* fix nnictl bug

* fix nnictl create bug

* add experiment status logic

* add more information for nnictl

* fix Evolution Tuner bug

* refactor code

* fix code in updater.py

* fix nnictl --help

* fix classArgs bug

* update check response.status_code logic

* show trial log path

* update document

* fix install.sh

* set default value for maxTrialNum and maxExecDuration

* fix nnictl

* fix config path hint

* support multiPhase

* fix bash-completion

* refactor bash-completion

* add sklearn-regression

* add search_space

* fix bug

* fix install.sh

* refactor code

* remove unused code

* support multi experiments

* fix issues
  • Loading branch information
SparkSnail authored Oct 10, 2018
1 parent 6ef6511 commit 9a2a168
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 230 deletions.
1 change: 0 additions & 1 deletion tools/bash-completion
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ __nnictl_log_cmds="stdout stderr"
__nnictl_log_stdout_cmds="--tail --head --path"
__nnictl_log_stderr_cmds="--tail --head --path"


# list of arguments that accept a file name
__nnictl_file_args=" --config -c --filename -f "

Expand Down
9 changes: 5 additions & 4 deletions tools/nnicmd/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
import os
import json
import shutil
from .constants import METADATA_DIR, METADATA_FULL_PATH
from .constants import HOME_DIR

class Config:
'''a util class to load and save config'''
def __init__(self):
os.makedirs(METADATA_DIR, exist_ok=True)
self.config_file = METADATA_FULL_PATH
def __init__(self, port):
config_path = os.path.join(HOME_DIR, str(port))
os.makedirs(config_path, exist_ok=True)
self.config_file = os.path.join(config_path, '.config')
self.config = self.read_file()

def get_all_config(self):
Expand Down
15 changes: 1 addition & 14 deletions tools/nnicmd/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,7 @@

import os

REST_PORT = 51188

HOME_DIR = os.path.join(os.environ['HOME'], 'nni')

METADATA_DIR = os.path.join(HOME_DIR, 'nnictl')

METADATA_FULL_PATH = os.path.join(METADATA_DIR, 'metadata')

LOG_DIR = os.path.join(HOME_DIR, 'nnictl', 'log')

STDOUT_FULL_PATH = os.path.join(LOG_DIR, 'stdout')

STDERR_FULL_PATH = os.path.join(LOG_DIR, 'stderr')
HOME_DIR = os.path.join(os.environ['HOME'], '.local', 'nni', 'nnictl')

ERROR_INFO = 'ERROR: %s'

Expand All @@ -44,7 +32,6 @@
'-----------------------------------------------------------------------\n' \
'The experiment id is %s\n'\
'The restful server post is %s\n' \
'The Web UI urls are: %s\n' \
'-----------------------------------------------------------------------\n\n' \
'You can use these commands to get more information about the experiment\n' \
'-----------------------------------------------------------------------\n' \
Expand Down
76 changes: 34 additions & 42 deletions tools/nnicmd/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,11 @@
from .config_utils import Config
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, detect_process
from .constants import *
from .webui_utils import start_web_ui, check_web_ui

def start_rest_server(port, platform, mode, experiment_id=None):
'''Run nni manager process'''
print_normal('Checking environment...')
nni_config = Config()
nni_config = Config(port)
rest_port = nni_config.get_config('restServerPort')
running, _ = check_rest_server_quick(rest_port)
if rest_port and running:
Expand All @@ -50,10 +49,10 @@ def start_rest_server(port, platform, mode, experiment_id=None):
cmds = [manager, '--port', str(port), '--mode', platform, '--start_mode', mode]
if mode == 'resume':
cmds += ['--experiment_id', experiment_id]
if not os.path.exists(LOG_DIR):
os.makedirs(LOG_DIR)
stdout_file = open(STDOUT_FULL_PATH, 'a+')
stderr_file = open(STDERR_FULL_PATH, 'a+')
stdout_full_path = os.path.join(HOME_DIR, str(port), 'stdout')
stderr_full_path = os.path.join(HOME_DIR, str(port), 'stderr')
stdout_file = open(stdout_full_path, 'a+')
stderr_file = open(stderr_full_path, 'a+')
process = Popen(cmds, stdout=stdout_file, stderr=stderr_file)
return process

Expand All @@ -80,7 +79,8 @@ def set_trial_config(experiment_config, port):
return True
else:
print('Error message is {}'.format(response.text))
with open(STDERR_FULL_PATH, 'a+') as fout:
stderr_full_path = os.path.join(HOME_DIR, str(port), 'stderr')
with open(stderr_full_path, 'a+') as fout:
fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
return False

Expand All @@ -98,7 +98,8 @@ def set_remote_config(experiment_config, port):
if not response or not check_response(response):
if response is not None:
err_message = response.text
with open(STDERR_FULL_PATH, 'a+') as fout:
stderr_full_path = os.path.join(HOME_DIR, str(port), 'stderr')
with open(stderr_full_path, 'a+') as fout:
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message

Expand Down Expand Up @@ -171,22 +172,23 @@ def set_experiment(experiment_config, mode, port):
if check_response(response):
return response
else:
with open(STDERR_FULL_PATH, 'a+') as fout:
stderr_full_path = os.path.join(HOME_DIR, str(port), 'stderr')
with open(stderr_full_path, 'a+') as fout:
fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
print_error('Setting experiment error, error message is {}'.format(response.text))
return None

def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=None):
def launch_experiment(args, experiment_config, mode, experiment_id=None):
'''follow steps to start rest server and start experiment'''
nni_config = Config()
nni_config = Config(args.port)
#Check if there is an experiment running
origin_rest_pid = nni_config.get_config('restServerPid')
if origin_rest_pid and detect_process(origin_rest_pid):
print_error('There is an experiment running, please stop it first...')
print_normal('You can use \'nnictl stop\' command to stop an experiment!')
exit(0)
exit(1)
# start rest server
rest_process = start_rest_server(REST_PORT, experiment_config['trainingServicePlatform'], mode, experiment_id)
rest_process = start_rest_server(args.port, experiment_config['trainingServicePlatform'], mode, experiment_id)
nni_config.set_config('restServerPid', rest_process.pid)
# Deal with annotation
if experiment_config.get('useAnnotation'):
Expand All @@ -206,7 +208,7 @@ def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=No
experiment_config['searchSpace'] = json.dumps('')

# check rest server
running, _ = check_rest_server(REST_PORT)
running, _ = check_rest_server(args.port)
if running:
print_normal('Successfully started Restful server!')
else:
Expand All @@ -216,12 +218,12 @@ def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=No
call(cmds)
except Exception:
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(0)
exit(1)

# set remote config
if experiment_config['trainingServicePlatform'] == 'remote':
print_normal('Setting remote config...')
config_result, err_msg = set_remote_config(experiment_config, REST_PORT)
config_result, err_msg = set_remote_config(experiment_config, args.port)
if config_result:
print_normal('Success!')
else:
Expand All @@ -231,12 +233,12 @@ def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=No
call(cmds)
except Exception:
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(0)
exit(1)

# set local config
if experiment_config['trainingServicePlatform'] == 'local':
print_normal('Setting local config...')
if set_local_config(experiment_config, REST_PORT):
if set_local_config(experiment_config, args.port):
print_normal('Successfully set local config!')
else:
print_error('Failed!')
Expand All @@ -245,12 +247,12 @@ def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=No
call(cmds)
except Exception:
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(0)
exit(1)

#set pai config
if experiment_config['trainingServicePlatform'] == 'pai':
print_normal('Setting pai config...')
config_result, err_msg = set_pai_config(experiment_config, REST_PORT)
config_result, err_msg = set_pai_config(experiment_config, args.port)
if config_result:
print_normal('Successfully set pai config!')
else:
Expand All @@ -261,22 +263,11 @@ def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=No
call(cmds)
except Exception:
raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(0)

#start webui
if check_web_ui():
print_warning('{0} {1}'.format(' '.join(nni_config.get_config('webuiUrl')),'is being used, please stop it first!'))
print_normal('You can use \'nnictl webui stop\' to stop old Web UI process...')
else:
print_normal('Starting Web UI...')
webui_process = start_web_ui(webuiport)
if webui_process:
nni_config.set_config('webuiPid', webui_process.pid)
print_normal('Successfully started Web UI!')
exit(1)

# start a new experiment
print_normal('Starting experiment...')
response = set_experiment(experiment_config, mode, REST_PORT)
response = set_experiment(experiment_config, mode, args.port)
if response:
if experiment_id is None:
experiment_id = json.loads(response.text).get('experiment_id')
Expand All @@ -286,27 +277,28 @@ def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=No
try:
cmds = ['pkill', '-P', str(rest_process.pid)]
call(cmds)
cmds = ['pkill', '-P', str(webui_process.pid)]
call(cmds)
except Exception:
raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(0)
print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, REST_PORT, ' '.join(nni_config.get_config('webuiUrl'))))
exit(1)
print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, args.port))

def resume_experiment(args):
'''resume an experiment'''
nni_config = Config()
nni_config = Config(args.port)
experiment_config = nni_config.get_config('experimentConfig')
experiment_id = nni_config.get_config('experimentId')
launch_experiment(args, experiment_config, 'resume', args.webuiport, experiment_id)
launch_experiment(args, experiment_config, 'resume', experiment_id)

def create_experiment(args):
'''start a new experiment'''
nni_config = Config()
nni_config = Config(args.port)
config_path = os.path.abspath(args.config)
if not os.path.exists(config_path):
print_error('Please set correct config path!')
exit(1)
experiment_config = get_yml_content(config_path)
validate_all_content(experiment_config, config_path)

nni_config.set_config('experimentConfig', experiment_config)
launch_experiment(args, experiment_config, 'new', args.webuiport)
nni_config.set_config('restServerPort', REST_PORT)
launch_experiment(args, experiment_config, 'new')
nni_config.set_config('restServerPort', args.port)
45 changes: 22 additions & 23 deletions tools/nnicmd/nnictl.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,87 +39,86 @@ def parse_args():
# parse start command
parser_start = subparsers.add_parser('create', help='create a new experiment')
parser_start.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file')
parser_start.add_argument('--webuiport', '-w', default=8080, dest='webuiport')
parser_start.add_argument('--port', '-p', default=51188, dest='port', help='the port of restful server')
parser_start.set_defaults(func=create_experiment)

# parse resume command
parser_resume = subparsers.add_parser('resume', help='resume a new experiment')
parser_resume.add_argument('--experiment', '-e', dest='id', help='ID of the experiment you want to resume')
parser_resume.add_argument('--manager', '-m', default='nnimanager', dest='manager')
parser_resume.add_argument('--webuiport', '-w', default=8080, dest='webuiport')
parser_resume.add_argument('--port', '-p', default=51188, dest='port', help='the port of restful server')
parser_resume.set_defaults(func=resume_experiment)

# parse update command
parser_updater = subparsers.add_parser('update', help='update the experiment')
#add subparsers for parser_updater
parser_updater_subparsers = parser_updater.add_subparsers()
parser_updater_searchspace = parser_updater_subparsers.add_parser('searchspace', help='update searchspace')
parser_updater_searchspace.add_argument('--port', '-p', default=51188, dest='port', help='the port of restful server')
parser_updater_searchspace.add_argument('--filename', '-f', required=True)
parser_updater_searchspace.set_defaults(func=update_searchspace)
parser_updater_searchspace = parser_updater_subparsers.add_parser('concurrency', help='update concurrency')
parser_updater_searchspace.add_argument('--value', '-v', required=True)
parser_updater_searchspace.set_defaults(func=update_concurrency)
parser_updater_searchspace = parser_updater_subparsers.add_parser('duration', help='update duration')
parser_updater_searchspace.add_argument('--value', '-v', required=True)
parser_updater_searchspace.set_defaults(func=update_duration)
parser_updater_concurrency = parser_updater_subparsers.add_parser('concurrency', help='update concurrency')
parser_updater_concurrency.add_argument('--port', '-p', default=51188, dest='port', help='the port of restful server')
parser_updater_concurrency.add_argument('--value', '-v', required=True)
parser_updater_concurrency.set_defaults(func=update_concurrency)
parser_updater_duration = parser_updater_subparsers.add_parser('duration', help='update duration')
parser_updater_duration.add_argument('--port', '-p', default=51188, dest='port', help='the port of restful server')
parser_updater_duration.add_argument('--value', '-v', required=True)
parser_updater_duration.set_defaults(func=update_duration)

#parse stop command
parser_stop = subparsers.add_parser('stop', help='stop the experiment')
parser_stop.add_argument('--port', '-p', default=51188, dest='port', help='the port of restful server')
parser_stop.set_defaults(func=stop_experiment)

#parse trial command
parser_trial = subparsers.add_parser('trial', help='get trial information')
#add subparsers for parser_trial
parser_trial_subparsers = parser_trial.add_subparsers()
parser_trial_ls = parser_trial_subparsers.add_parser('ls', help='list trial jobs')
parser_trial_ls.add_argument('--port', '-p', default=51188, dest='port', help='the port of restful server')
parser_trial_ls.set_defaults(func=trial_ls)
parser_trial_kill = parser_trial_subparsers.add_parser('kill', help='kill trial jobs')
parser_trial_kill.add_argument('--port', '-p', default=51188, dest='port', help='the port of restful server')
parser_trial_kill.add_argument('--trialid', '-t', required=True, dest='trialid', help='the id of trial to be killed')
parser_trial_kill.set_defaults(func=trial_kill)

#TODO:finish webui function
#parse board command
parser_webui = subparsers.add_parser('webui', help='get web ui information')
#add subparsers for parser_board
parser_webui_subparsers = parser_webui.add_subparsers()
parser_webui_start = parser_webui_subparsers.add_parser('start', help='start web ui')
parser_webui_start.add_argument('--port', '-p', dest='port', default=8080, help='the port of web ui')
parser_webui_start.set_defaults(func=start_webui)
parser_webui_stop = parser_webui_subparsers.add_parser('stop', help='stop web ui')
parser_webui_stop.set_defaults(func=stop_webui)
parser_webui_url = parser_webui_subparsers.add_parser('url', help='show the url of web ui')
parser_webui_url.set_defaults(func=webui_url)

#parse experiment command
parser_experiment = subparsers.add_parser('experiment', help='get experiment information')
#add subparsers for parser_experiment
parser_experiment_subparsers = parser_experiment.add_subparsers()
parser_experiment_show = parser_experiment_subparsers.add_parser('show', help='show the information of experiment')
parser_experiment_show.add_argument('--port', '-p', default=51188, dest='port', help='the port of restful server')
parser_experiment_show.set_defaults(func=list_experiment)
parser_experiment_status = parser_experiment_subparsers.add_parser('status', help='show the status of experiment')
parser_experiment_status.add_argument('--port', '-p', default=51188, dest='port', help='the port of restful server')
parser_experiment_status.set_defaults(func=experiment_status)

#parse config command
parser_config = subparsers.add_parser('config', help='get config information')
parser_config_subparsers = parser_config.add_subparsers()
parser_config_show = parser_config_subparsers.add_parser('show', help='show the information of config')
parser_config_show.add_argument('--port', '-p', default=51188, dest='port', help='the port of restful server')
parser_config_show.set_defaults(func=get_config)

#parse log command
parser_log = subparsers.add_parser('log', help='get log information')
# add subparsers for parser_log
parser_log_subparsers = parser_log.add_subparsers()
parser_log_stdout = parser_log_subparsers.add_parser('stdout', help='get stdout information')
parser_log_stdout.add_argument('--port', default=51188, dest='port', help='the port of restful server')
parser_log_stdout.add_argument('--tail', '-T', dest='tail', type=int, help='get tail -100 content of stdout')
parser_log_stdout.add_argument('--head', '-H', dest='head', type=int, help='get head -100 content of stdout')
parser_log_stdout.add_argument('--path', '-p', action='store_true', default=False, help='get the path of stdout file')
parser_log_stdout.add_argument('--path', action='store_true', default=False, help='get the path of stdout file')
parser_log_stdout.set_defaults(func=log_stdout)
parser_log_stderr = parser_log_subparsers.add_parser('stderr', help='get stderr information')
parser_log_stderr.add_argument('--port', default=51188, dest='port', help='the port of restful server')
parser_log_stderr.add_argument('--tail', '-T', dest='tail', type=int, help='get tail -100 content of stderr')
parser_log_stderr.add_argument('--head', '-H', dest='head', type=int, help='get head -100 content of stderr')
parser_log_stderr.add_argument('--path', '-p', action='store_true', default=False, help='get the path of stderr file')
parser_log_stderr.add_argument('--path', action='store_true', default=False, help='get the path of stderr file')
parser_log_stderr.set_defaults(func=log_stderr)
parser_log_trial = parser_log_subparsers.add_parser('trial', help='get trial log path')
parser_log_trial.add_argument('--port', '-p', default=51188, dest='port', help='the port of restful server')
parser_log_trial.add_argument('--id', '-I', dest='id', help='find trial log path by id')
parser_log_trial.set_defaults(func=log_trial)

Expand Down
Loading

0 comments on commit 9a2a168

Please sign in to comment.