Skip to content

Commit

Permalink
experiment management backend (microsoft#3081)
Browse files Browse the repository at this point in the history
* step 1 nnictl generate experimentId & merge folder

* step 2.1 modify .experiment structure

* step 2.2 add lock for .experiment rw in nnictl

* step 2.2 add filelock dependence

* step 2.2 remove uniqueString from main.js

* fix test bug

* fix test bug

* step 3.1 add experiment manager

* step 3.2 add getExperimentsInfo

* fix eslint

* add a simple file lock to support stale

* step 3.3 add test

* divide abs experiment manager from manager

* experiment manager refactor

* support .experiment sync update status

* nnictl no longer uses rest api to update status or endtime

* nnictl no longer uses rest api to update status or endtime

* fix eslint

* support .experiment sync update endtime

* fix test

* fix settimeout bug

* fix test

* adjust experiment endTime

* separate simple file lock class

* modify name

* add 'id' in .experiment

* update rest api format

* fix eslint

* fix issue in comments

* fix rest api format

* add indent in json in experiments manager

* fix unittest

* fix unittest

* refactor file lock

* fix eslint

* remove '__enter__' in filelock

* filelock support never expire

Co-authored-by: Ning Shang <nishang@microsoft.com>
  • Loading branch information
J-shang and Ning Shang authored Nov 30, 2020
1 parent fc0ff8c commit 95f731e
Show file tree
Hide file tree
Showing 22 changed files with 546 additions and 127 deletions.
36 changes: 36 additions & 0 deletions nni/tools/nnictl/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
import sys
import json
import tempfile
import time
import socket
import string
import random
import ruamel.yaml as yaml
import psutil
import filelock
import glob
from colorama import Fore

from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO
Expand Down Expand Up @@ -95,3 +98,36 @@ def generate_folder_name():
temp_dir = generate_folder_name()
os.makedirs(temp_dir)
return temp_dir

class SimplePreemptiveLock(filelock.SoftFileLock):
    '''Soft file lock whose marker files can be treated as expired.

    The holder marks the lock with a file named ``<lock_file>.<pid>``. An
    acquire attempt backs off while any ``<lock_file>.*`` file exists and is
    still live: with ``stale < 0`` every existing file counts as live,
    otherwise a file is live only while its mtime is less than ``stale``
    seconds old. If you do not need the expiration check, use
    ``SoftFileLock`` directly.
    '''

    def __init__(self, lock_file, stale=-1):
        # timeout=-1 is forwarded unchanged to the filelock base class.
        super().__init__(lock_file, timeout=-1)
        self._lock_file_name = '{}.{}'.format(self._lock_file, os.getpid())
        self._stale = stale

    def _acquire(self):
        '''Try once to take the lock; on failure leave ``_lock_file_fd`` untouched.'''
        flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | os.O_TRUNC
        try:
            for candidate in glob.glob(self._lock_file + '.*'):
                if not os.path.exists(candidate):
                    continue
                # A negative stale value means marker files never expire.
                if self._stale < 0:
                    return None
                age = time.time() - os.stat(candidate).st_mtime
                if age < self._stale:
                    return None
            # O_EXCL makes creation fail if our own marker file already exists.
            self._lock_file_fd = os.open(self._lock_file_name, flags)
        except (IOError, OSError):
            # Any filesystem error simply counts as "not acquired this time".
            pass
        return None

    def _release(self):
        '''Close the held descriptor and delete this process's marker file.'''
        os.close(self._lock_file_fd)
        self._lock_file_fd = None
        try:
            os.remove(self._lock_file_name)
        except OSError:
            # Best effort: the marker file may already be gone.
            pass

def get_file_lock(path: str, stale=-1):
    '''Return a SimplePreemptiveLock guarding ``path``.

    The lock uses ``path + '.lock'`` as its lock file name.

    Parameters
    ----------
    path : str
        Path of the file to protect.
    stale : int or float
        Forwarded to SimplePreemptiveLock as its ``stale`` setting; the
        default of -1 is the "never expire" value.
    '''
    # Forward the caller's value; it used to be hard-coded to -1, which
    # silently disabled expiration for every caller.
    return SimplePreemptiveLock(path + '.lock', stale=stale)
72 changes: 44 additions & 28 deletions nni/tools/nnictl/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import os
import json
import shutil
import time
from .constants import NNICTL_HOME_DIR
from .command_utils import print_error
from .common_utils import get_file_lock

class Config:
'''a util class to load and save config'''
Expand Down Expand Up @@ -34,7 +36,7 @@ def write_file(self):
if self.config:
try:
with open(self.config_file, 'w') as file:
json.dump(self.config, file)
json.dump(self.config, file, indent=4)
except IOError as error:
print('Error:', error)
return
Expand All @@ -54,39 +56,53 @@ class Experiments:
def __init__(self, home_dir=NNICTL_HOME_DIR):
os.makedirs(home_dir, exist_ok=True)
self.experiment_file = os.path.join(home_dir, '.experiment')
self.experiments = self.read_file()
self.lock = get_file_lock(self.experiment_file, stale=2)
with self.lock:
self.experiments = self.read_file()

def add_experiment(self, expId, port, startTime, file_name, platform, experiment_name, endTime='N/A', status='INITIALIZED'):
'''set {key:value} paris to self.experiment'''
self.experiments[expId] = {}
self.experiments[expId]['port'] = port
self.experiments[expId]['startTime'] = startTime
self.experiments[expId]['endTime'] = endTime
self.experiments[expId]['status'] = status
self.experiments[expId]['fileName'] = file_name
self.experiments[expId]['platform'] = platform
self.experiments[expId]['experimentName'] = experiment_name
self.write_file()
def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED',
                   tag=[], pid=None, webuiUrl=[], logDir=[]):
    '''Create (or overwrite) the record for ``expId`` and save it.

    While holding ``self.lock``, the records are re-read with
    ``self.read_file()``, the entry for ``expId`` is replaced by a new dict
    built from the arguments, and ``self.write_file()`` is called.

    List arguments (including the list defaults of ``tag``, ``webuiUrl`` and
    ``logDir``) are copied before being stored, so no two records share one
    list object. Non-list values such as ``None`` are stored unchanged.
    '''
    def _own_copy(value):
        # The [] defaults in the signature are created once and shared by
        # every call; store a private copy instead of the shared object.
        return list(value) if isinstance(value, list) else value

    with self.lock:
        # Reload inside the lock so this update is applied to the latest records.
        self.experiments = self.read_file()
        self.experiments[expId] = {
            'id': expId,
            'port': port,
            'startTime': startTime,
            'endTime': endTime,
            'status': status,
            'platform': platform,
            'experimentName': experiment_name,
            'tag': _own_copy(tag),
            'pid': pid,
            'webuiUrl': _own_copy(webuiUrl),
            'logDir': _own_copy(logDir),
        }
        self.write_file()

def update_experiment(self, expId, key, value):
'''Update experiment'''
if expId not in self.experiments:
return False
self.experiments[expId][key] = value
self.write_file()
return True
with self.lock:
self.experiments = self.read_file()
if expId not in self.experiments:
return False
self.experiments[expId][key] = value
self.write_file()
return True

def remove_experiment(self, expId):
'''remove an experiment by id'''
if expId in self.experiments:
fileName = self.experiments.pop(expId).get('fileName')
if fileName:
logPath = os.path.join(NNICTL_HOME_DIR, fileName)
try:
shutil.rmtree(logPath)
except FileNotFoundError:
print_error('{0} does not exist.'.format(logPath))
self.write_file()
with self.lock:
self.experiments = self.read_file()
if expId in self.experiments:
self.experiments.pop(expId)
fileName = expId
if fileName:
logPath = os.path.join(NNICTL_HOME_DIR, fileName)
try:
shutil.rmtree(logPath)
except FileNotFoundError:
print_error('{0} does not exist.'.format(logPath))
self.write_file()

def get_all_experiments(self):
'''return all of experiments'''
Expand All @@ -96,7 +112,7 @@ def write_file(self):
'''save config to local file'''
try:
with open(self.experiment_file, 'w') as file:
json.dump(self.experiments, file)
json.dump(self.experiments, file, indent=4)
except IOError as error:
print('Error:', error)
return ''
Expand Down
2 changes: 1 addition & 1 deletion nni/tools/nnictl/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from colorama import Fore

NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), '.local', 'nnictl')
NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments')

NNI_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments')

Expand Down
74 changes: 35 additions & 39 deletions nni/tools/nnictl/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@
from .command_utils import check_output_command, kill_command
from .nnictl_utils import update_experiment

def get_log_path(config_file_name):
def get_log_path(experiment_id):
'''generate stdout and stderr log path'''
stdout_full_path = os.path.join(NNICTL_HOME_DIR, config_file_name, 'stdout')
stderr_full_path = os.path.join(NNICTL_HOME_DIR, config_file_name, 'stderr')
os.makedirs(os.path.join(NNICTL_HOME_DIR, experiment_id, 'log'), exist_ok=True)
stdout_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stdout.log')
stderr_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stderr.log')
return stdout_full_path, stderr_full_path

def print_log_content(config_file_name):
Expand All @@ -38,7 +39,7 @@ def print_log_content(config_file_name):
print_normal(' Stderr:')
print(check_output_command(stderr_full_path))

def start_rest_server(port, platform, mode, config_file_name, foreground=False, experiment_id=None, log_dir=None, log_level=None):
def start_rest_server(port, platform, mode, experiment_id, foreground=False, log_dir=None, log_level=None):
'''Run nni manager process'''
if detect_port(port):
print_error('Port %s is used by another process, please reset the port!\n' \
Expand All @@ -63,7 +64,8 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
node_command = os.path.join(entry_dir, 'node.exe')
else:
node_command = os.path.join(entry_dir, 'node')
cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform]
cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform, \
'--experiment_id', experiment_id]
if mode == 'view':
cmds += ['--start_mode', 'resume']
cmds += ['--readonly', 'true']
Expand All @@ -73,13 +75,12 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
cmds += ['--log_dir', log_dir]
if log_level is not None:
cmds += ['--log_level', log_level]
if mode in ['resume', 'view']:
cmds += ['--experiment_id', experiment_id]
if foreground:
cmds += ['--foreground', 'true']
stdout_full_path, stderr_full_path = get_log_path(config_file_name)
stdout_full_path, stderr_full_path = get_log_path(experiment_id)
with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
start_time = time.time()
time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
#add time information in the header of log files
log_header = LOG_HEADER % str(time_now)
stdout_file.write(log_header)
Expand All @@ -95,7 +96,7 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=PIPE)
else:
process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file)
return process, str(time_now)
return process, int(start_time * 1000)

def set_trial_config(experiment_config, port, config_file_name):
'''set trial configuration'''
Expand Down Expand Up @@ -432,9 +433,9 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(1)

def launch_experiment(args, experiment_config, mode, config_file_name, experiment_id=None):
def launch_experiment(args, experiment_config, mode, experiment_id):
'''follow steps to start rest server and start experiment'''
nni_config = Config(config_file_name)
nni_config = Config(experiment_id)
# check packages for tuner
package_name, module_name = None, None
if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
Expand All @@ -445,15 +446,15 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
module_name, _ = get_builtin_module_class_name('advisors', package_name)
if package_name and module_name:
try:
stdout_full_path, stderr_full_path = get_log_path(config_file_name)
stdout_full_path, stderr_full_path = get_log_path(experiment_id)
with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
check_call([sys.executable, '-c', 'import %s'%(module_name)], stdout=stdout_file, stderr=stderr_file)
except CalledProcessError:
print_error('some errors happen when import package %s.' %(package_name))
print_log_content(config_file_name)
print_log_content(experiment_id)
if package_name in INSTALLABLE_PACKAGE_META:
print_error('If %s is not installed, it should be installed through '\
'\'nnictl package install --name %s\''%(package_name, package_name))
'\'nnictl package install --name %s\'' % (package_name, package_name))
exit(1)
log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None
log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None
Expand All @@ -465,7 +466,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
log_level = 'debug'
# start rest server
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
mode, config_file_name, foreground, experiment_id, log_dir, log_level)
mode, experiment_id, foreground, log_dir, log_level)
nni_config.set_config('restServerPid', rest_process.pid)
# Deal with annotation
if experiment_config.get('useAnnotation'):
Expand All @@ -491,7 +492,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
print_normal('Successfully started Restful server!')
else:
print_error('Restful server start failed!')
print_log_content(config_file_name)
print_log_content(experiment_id)
try:
kill_command(rest_process.pid)
except Exception:
Expand All @@ -500,21 +501,25 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
if mode != 'view':
# set platform configuration
set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\
config_file_name, rest_process)
experiment_id, rest_process)

# start a new experiment
print_normal('Starting experiment...')
# save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
# set debug configuration
if mode != 'view' and experiment_config.get('debug') is None:
experiment_config['debug'] = args.debug
response = set_experiment(experiment_config, mode, args.port, config_file_name)
response = set_experiment(experiment_config, mode, args.port, experiment_id)
if response:
if experiment_id is None:
experiment_id = json.loads(response.text).get('experiment_id')
nni_config.set_config('experimentId', experiment_id)
else:
print_error('Start experiment failed!')
print_log_content(config_file_name)
print_log_content(experiment_id)
try:
kill_command(rest_process.pid)
except Exception:
Expand All @@ -526,12 +531,6 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
web_ui_url_list = get_local_urls(args.port)
nni_config.set_config('webuiUrl', web_ui_url_list)

# save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, config_file_name,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'])

print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
if mode != 'view' and args.foreground:
try:
Expand All @@ -544,8 +543,9 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen

def create_experiment(args):
'''start a new experiment'''
config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
nni_config = Config(config_file_name)
experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8))
nni_config = Config(experiment_id)
nni_config.set_config('experimentId', experiment_id)
config_path = os.path.abspath(args.config)
if not os.path.exists(config_path):
print_error('Please set correct config path!')
Expand All @@ -560,9 +560,9 @@ def create_experiment(args):
nni_config.set_config('experimentConfig', experiment_config)
nni_config.set_config('restServerPort', args.port)
try:
launch_experiment(args, experiment_config, 'new', config_file_name)
launch_experiment(args, experiment_config, 'new', experiment_id)
except Exception as exception:
nni_config = Config(config_file_name)
nni_config = Config(experiment_id)
restServerPid = nni_config.get_config('restServerPid')
if restServerPid:
kill_command(restServerPid)
Expand All @@ -589,17 +589,13 @@ def manage_stopped_experiment(args, mode):
exit(1)
experiment_id = args.id
print_normal('{0} experiment {1}...'.format(mode, experiment_id))
nni_config = Config(experiment_dict[experiment_id]['fileName'])
nni_config = Config(experiment_id)
experiment_config = nni_config.get_config('experimentConfig')
experiment_id = nni_config.get_config('experimentId')
new_config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
new_nni_config = Config(new_config_file_name)
new_nni_config.set_config('experimentConfig', experiment_config)
new_nni_config.set_config('restServerPort', args.port)
nni_config.set_config('restServerPort', args.port)
try:
launch_experiment(args, experiment_config, mode, new_config_file_name, experiment_id)
launch_experiment(args, experiment_config, mode, experiment_id)
except Exception as exception:
nni_config = Config(new_config_file_name)
nni_config = Config(experiment_id)
restServerPid = nni_config.get_config('restServerPid')
if restServerPid:
kill_command(restServerPid)
Expand Down
4 changes: 4 additions & 0 deletions nni/tools/nnictl/launcher_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def parse_time(time):
def parse_path(experiment_config, config_path):
'''Parse path in config file'''
expand_path(experiment_config, 'searchSpacePath')
if experiment_config.get('logDir'):
expand_path(experiment_config, 'logDir')
if experiment_config.get('trial'):
expand_path(experiment_config['trial'], 'codeDir')
if experiment_config['trial'].get('authFile'):
Expand Down Expand Up @@ -65,6 +67,8 @@ def parse_path(experiment_config, config_path):
root_path = os.path.dirname(config_path)
if experiment_config.get('searchSpacePath'):
parse_relative_path(root_path, experiment_config, 'searchSpacePath')
if experiment_config.get('logDir'):
parse_relative_path(root_path, experiment_config, 'logDir')
if experiment_config.get('trial'):
parse_relative_path(root_path, experiment_config['trial'], 'codeDir')
if experiment_config['trial'].get('authFile'):
Expand Down
Loading

0 comments on commit 95f731e

Please sign in to comment.