From 4900e3b78e75981a1d726022481ce73ba826bc7f Mon Sep 17 00:00:00 2001 From: Hao Ni Date: Wed, 12 May 2021 14:05:58 +0800 Subject: [PATCH 1/5] add prefix url --- nni/tools/nnictl/launcher.py | 10 ++++++--- nni/tools/nnictl/nnictl.py | 1 + nni/tools/nnictl/url_utils.py | 10 +++++++++ .../common/experimentStartupInfo.ts | 22 +++++++++++++++---- ts/nni_manager/main.ts | 8 ++++--- ts/nni_manager/rest_server/nniRestServer.ts | 3 ++- 6 files changed, 43 insertions(+), 11 deletions(-) diff --git a/nni/tools/nnictl/launcher.py b/nni/tools/nnictl/launcher.py index 7b2d20f9b8..5b3acf28c7 100644 --- a/nni/tools/nnictl/launcher.py +++ b/nni/tools/nnictl/launcher.py @@ -16,7 +16,7 @@ import nni_node # pylint: disable=import-error from .launcher_utils import validate_all_content from .rest_utils import rest_put, rest_post, check_rest_server, check_response -from .url_utils import cluster_metadata_url, experiment_url, get_local_urls +from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, set_prefix_url from .config_utils import Config, Experiments from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \ detect_port, get_user @@ -43,7 +43,7 @@ def print_log_content(config_file_name): print_normal(' Stderr:') print(check_output_command(stderr_full_path)) -def start_rest_server(port, platform, mode, experiment_id, foreground=False, log_dir=None, log_level=None): +def start_rest_server(port, platform, mode, experiment_id, foreground=False, log_dir=None, log_level=None, url_prefix=None): '''Run nni manager process''' if detect_port(port): print_error('Port %s is used by another process, please reset the port!\n' \ @@ -81,6 +81,10 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log cmds += ['--log_level', log_level] if foreground: cmds += ['--foreground', 'true'] + if url_prefix: + set_prefix_url(url_prefix) + cmds += ['--url_prefix', url_prefix] + stdout_full_path, stderr_full_path = get_log_path(experiment_id) with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file: start_time = time.time() @@ -384,7 +388,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi platform = experiment_config['trainingService']['platform'] rest_process, start_time = start_rest_server(args.port, platform, \ - mode, experiment_id, foreground, log_dir, log_level) + mode, experiment_id, foreground, log_dir, log_level, args.url_prefix) # save experiment information Experiments().add_experiment(experiment_id, args.port, start_time, platform, diff --git a/nni/tools/nnictl/nnictl.py b/nni/tools/nnictl/nnictl.py index fd8697337e..d4b55d506a 100644 --- a/nni/tools/nnictl/nnictl.py +++ b/nni/tools/nnictl/nnictl.py @@ -54,6 +54,7 @@ def parse_args(): parser_start.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file') parser_start.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', type=int, help='the port of restful server') parser_start.add_argument('--debug', '-d', action='store_true', help=' set debug mode') + parser_start.add_argument('--url_prefix', '-u', dest='url_prefix', help=' set prefix url') parser_start.add_argument('--foreground', '-f', action='store_true', help=' set foreground mode, print log content to terminal') parser_start.set_defaults(func=create_experiment) diff --git a/nni/tools/nnictl/url_utils.py b/nni/tools/nnictl/url_utils.py index 59a28837a6..091fcb9de3 100644 --- a/nni/tools/nnictl/url_utils.py +++ b/nni/tools/nnictl/url_utils.py @@ -3,6 +3,7 @@ import socket import psutil +import re BASE_URL = 'http://localhost' @@ -24,6 +25,15 @@ METRIC_DATA_API = '/metric-data' +def path_validation(path): + assert re.match("^[A-Za-z0-9_-]*$", path), "prefix url is invalid." + +def set_prefix_url(prefix): + '''set prefix url''' + path_validation(prefix) + global API_ROOT_URL + API_ROOT_URL = '{0}/{1}'.format(API_ROOT_URL, prefix) + def metric_data_url(port): '''get metric_data url''' return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, METRIC_DATA_API) diff --git a/ts/nni_manager/common/experimentStartupInfo.ts b/ts/nni_manager/common/experimentStartupInfo.ts index 5316abd26e..c56afb9409 100644 --- a/ts/nni_manager/common/experimentStartupInfo.ts +++ b/ts/nni_manager/common/experimentStartupInfo.ts @@ -19,8 +19,9 @@ class ExperimentStartupInfo { private readonly: boolean = false; private dispatcherPipe: string | null = null; private platform: string = ''; + private urlprefix: string = ''; - public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean, dispatcherPipe?: string): void { + public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean, dispatcherPipe?: string, urlprefix?: string): void { assert(!this.initialized); assert(experimentId.trim().length > 0); this.newExperiment = newExperiment; @@ -46,6 +47,10 @@ class ExperimentStartupInfo { if (dispatcherPipe != undefined && dispatcherPipe.length > 0) { this.dispatcherPipe = dispatcherPipe; } + + if(urlprefix != undefined && urlprefix.length > 0){ + this.urlprefix = urlprefix; + } } public getExperimentId(): string { @@ -94,6 +99,11 @@ class ExperimentStartupInfo { assert(this.initialized); return this.dispatcherPipe; } + + public getUrlPrefix(): string { + assert(this.initialized); + return this.urlprefix; + } } function getExperimentId(): string { @@ -117,9 +127,9 @@ function getExperimentStartupInfo(): ExperimentStartupInfo { } function setExperimentStartupInfo( - newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean, dispatcherPipe?: string): void { + newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean, dispatcherPipe?: string, urlprefix?: string): void { component.get(ExperimentStartupInfo) - .setStartupInfo(newExperiment, experimentId, basePort, platform, logDir, logLevel, readonly, dispatcherPipe); + .setStartupInfo(newExperiment, experimentId, basePort, platform, logDir, logLevel, readonly, dispatcherPipe, urlprefix); } function isReadonly(): boolean { @@ -130,7 +140,11 @@ function getDispatcherPipe(): string | null { return component.get(ExperimentStartupInfo).getDispatcherPipe(); } +function getUrlPrefix(): string { + return component.get(ExperimentStartupInfo).getUrlPrefix(); +} + export { ExperimentStartupInfo, getBasePort, getExperimentId, isNewExperiment, getPlatform, getExperimentStartupInfo, - setExperimentStartupInfo, isReadonly, getDispatcherPipe + setExperimentStartupInfo, isReadonly, getDispatcherPipe, getUrlPrefix }; diff --git a/ts/nni_manager/main.ts b/ts/nni_manager/main.ts index 3b6c9bae90..c104c3a190 100644 --- a/ts/nni_manager/main.ts +++ b/ts/nni_manager/main.ts @@ -25,9 +25,9 @@ import { NNIRestServer } from './rest_server/nniRestServer'; function initStartupInfo( startExpMode: string, experimentId: string, basePort: number, platform: string, - logDirectory: string, experimentLogLevel: string, readonly: boolean, dispatcherPipe: string): void { + logDirectory: string, experimentLogLevel: string, readonly: boolean, dispatcherPipe: string, urlprefix: string): void { const createNew: boolean = (startExpMode === ExperimentStartUpMode.NEW); - setExperimentStartupInfo(createNew, experimentId, basePort, platform, logDirectory, experimentLogLevel, readonly, dispatcherPipe); + setExperimentStartupInfo(createNew, experimentId, basePort, platform, logDirectory, experimentLogLevel, readonly, dispatcherPipe, urlprefix); } async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise { @@ -122,7 +122,9 @@ const readonly = readonlyArg.toLowerCase() == 'true' ? true : false; const dispatcherPipe: string = parseArg(['--dispatcher_pipe']); -initStartupInfo(startMode, experimentId, port, mode, logDir, logLevel, readonly, dispatcherPipe); +const urlPrefix: string = parseArg(['--url_prefix']); + +initStartupInfo(startMode, experimentId, port, mode, logDir, logLevel, readonly, dispatcherPipe, urlPrefix); mkDirP(getLogDir()) .then(async () => { diff --git a/ts/nni_manager/rest_server/nniRestServer.ts b/ts/nni_manager/rest_server/nniRestServer.ts index d9e185773b..b14c081bd5 100644 --- a/ts/nni_manager/rest_server/nniRestServer.ts +++ b/ts/nni_manager/rest_server/nniRestServer.ts @@ -10,6 +10,7 @@ import * as component from '../common/component'; import { RestServer } from '../common/restServer' import { getLogDir } from '../common/utils'; import { createRestHandler } from './restHandler'; +import { getUrlPrefix } from '../common/experimentStartupInfo'; /** * NNI Main rest server, provides rest API to support @@ -19,7 +20,7 @@ import { createRestHandler } from './restHandler'; */ @component.Singleton export class NNIRestServer extends RestServer { - private readonly API_ROOT_URL: string = '/api/v1/nni'; + private readonly API_ROOT_URL: string = `/api/v1/nni/${getUrlPrefix()}`; private readonly LOGS_ROOT_URL: string = '/logs'; /** From f1bb0660a6795f02c3c2bf57e842e40750195c7e Mon Sep 17 00:00:00 2001 From: Hao Ni Date: Fri, 14 May 2021 13:49:00 +0800 Subject: [PATCH 2/5] change pre-setting to call with prefix parameter --- nni/tools/nnictl/config_utils.py | 3 +- nni/tools/nnictl/launcher.py | 81 +++++++++++++++++--------------- nni/tools/nnictl/nnictl_utils.py | 49 ++++++++++--------- nni/tools/nnictl/rest_utils.py | 8 ++-- nni/tools/nnictl/updater.py | 15 +++--- nni/tools/nnictl/url_utils.py | 43 ++++++++--------- 6 files changed, 105 insertions(+), 94 deletions(-) diff --git a/nni/tools/nnictl/config_utils.py b/nni/tools/nnictl/config_utils.py index 5cef7e2449..154b1d8ccb 100644 --- a/nni/tools/nnictl/config_utils.py +++ b/nni/tools/nnictl/config_utils.py @@ -108,7 +108,7 @@ def __init__(self, home_dir=NNI_HOME_DIR): self.experiments = self.read_file() def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED', - tag=[], pid=None, webuiUrl=[], logDir=''): + tag=[], pid=None, webuiUrl=[], logDir='', prefixUrl=None): '''set {key:value} pairs to self.experiment''' with self.lock: self.experiments = self.read_file() @@ -124,6 +124,7 @@ def add_experiment(self, expId, port, startTime, platform, experiment_name, endT self.experiments[expId]['pid'] = pid self.experiments[expId]['webuiUrl'] = webuiUrl self.experiments[expId]['logDir'] = str(logDir) + self.experiments[expId]['prefixUrl'] = prefixUrl self.write_file() def update_experiment(self, expId, key, value): diff --git a/nni/tools/nnictl/launcher.py b/nni/tools/nnictl/launcher.py index 5b3acf28c7..9fd1c56e0f 100644 --- a/nni/tools/nnictl/launcher.py +++ b/nni/tools/nnictl/launcher.py @@ -16,7 +16,7 @@ import nni_node # pylint: disable=import-error from .launcher_utils import validate_all_content from .rest_utils import rest_put, rest_post, check_rest_server, check_response -from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, set_prefix_url +from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, path_validation from .config_utils import Config, Experiments from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \ detect_port, get_user @@ -82,7 +82,7 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log if foreground: cmds += ['--foreground', 'true'] if url_prefix: - set_prefix_url(url_prefix) + path_validation(url_prefix) cmds += ['--url_prefix', url_prefix] stdout_full_path, stderr_full_path = get_log_path(experiment_id) @@ -106,11 +106,11 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file) return process, int(start_time * 1000) -def set_trial_config(experiment_config, port, config_file_name): +def set_trial_config(experiment_config, port, config_file_name, prefixUrl): '''set trial configuration''' request_data = dict() request_data['trial_config'] = experiment_config['trial'] - response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps(request_data), REST_TIME_OUT) if check_response(response): return True else: @@ -121,12 +121,12 @@ def set_trial_config(experiment_config, port, config_file_name): fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':'))) return False -def set_adl_config(experiment_config, port, config_file_name): +def set_adl_config(experiment_config, port, config_file_name, prefixUrl): '''set adl configuration''' adl_config_data = dict() # hack for supporting v2 config, need refactor adl_config_data['adl_config'] = {} - response = rest_put(cluster_metadata_url(port), json.dumps(adl_config_data), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps(adl_config_data), REST_TIME_OUT) err_message = None if not response or not response.status_code == 200: if response is not None: @@ -136,11 +136,11 @@ def set_adl_config(experiment_config, port, config_file_name): fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))) return False, err_message set_V1_common_config(experiment_config, port, config_file_name) - result, message = setNNIManagerIp(experiment_config, port, config_file_name) + result, message = setNNIManagerIp(experiment_config, port, config_file_name, prefixUrl) if not result: return result, message #set trial_config - return set_trial_config(experiment_config, port, config_file_name), None + return set_trial_config(experiment_config, port, config_file_name, prefixUrl), None def validate_response(response, config_file_name): err_message = None @@ -154,7 +154,7 @@ def validate_response(response, config_file_name): exit(1) # hack to fix v1 version_check and log_collection bug, need refactor -def set_V1_common_config(experiment_config, port, config_file_name): +def set_V1_common_config(experiment_config, port, config_file_name, prefixUrl): version_check = True #debug mode should disable version check if experiment_config.get('debug') is not None: @@ -162,19 +162,19 @@ def set_V1_common_config(experiment_config, port, config_file_name): #validate version check if experiment_config.get('versionCheck') is not None: version_check = experiment_config.get('versionCheck') - response = rest_put(cluster_metadata_url(port), json.dumps({'version_check': version_check}), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps({'version_check': version_check}), REST_TIME_OUT) validate_response(response, config_file_name) if experiment_config.get('logCollection'): - response = rest_put(cluster_metadata_url(port), json.dumps({'log_collection': experiment_config.get('logCollection')}), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps({'log_collection': experiment_config.get('logCollection')}), REST_TIME_OUT) validate_response(response, config_file_name) -def setNNIManagerIp(experiment_config, port, config_file_name): +def setNNIManagerIp(experiment_config, port, config_file_name, prefixUrl): '''set nniManagerIp''' if experiment_config.get('nniManagerIp') is None: return True, None ip_config_dict = dict() ip_config_dict['nni_manager_ip'] = {'nniManagerIp': experiment_config['nniManagerIp']} - response = rest_put(cluster_metadata_url(port), json.dumps(ip_config_dict), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps(ip_config_dict), REST_TIME_OUT) err_message = None if not response or not response.status_code == 200: if response is not None: @@ -185,11 +185,11 @@ def setNNIManagerIp(experiment_config, port, config_file_name): return False, err_message return True, None -def set_kubeflow_config(experiment_config, port, config_file_name): +def set_kubeflow_config(experiment_config, port, config_file_name, prefixUrl): '''set kubeflow configuration''' kubeflow_config_data = dict() kubeflow_config_data['kubeflow_config'] = experiment_config['kubeflowConfig'] - response = rest_put(cluster_metadata_url(port), json.dumps(kubeflow_config_data), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps(kubeflow_config_data), REST_TIME_OUT) err_message = None if not response or not response.status_code == 200: if response is not None: @@ -198,18 +198,18 @@ def set_kubeflow_config(experiment_config, port, config_file_name): with open(stderr_full_path, 'a+') as fout: fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))) return False, err_message - set_V1_common_config(experiment_config, port, config_file_name) - result, message = setNNIManagerIp(experiment_config, port, config_file_name) + set_V1_common_config(experiment_config, port, config_file_name, prefixUrl) + result, message = setNNIManagerIp(experiment_config, port, config_file_name, prefixUrl) if not result: return result, message #set trial_config - return set_trial_config(experiment_config, port, config_file_name), err_message + return set_trial_config(experiment_config, port, config_file_name, prefixUrl), err_message -def set_frameworkcontroller_config(experiment_config, port, config_file_name): +def set_frameworkcontroller_config(experiment_config, port, config_file_name, prefixUrl): '''set kubeflow configuration''' frameworkcontroller_config_data = dict() frameworkcontroller_config_data['frameworkcontroller_config'] = experiment_config['frameworkcontrollerConfig'] - response = rest_put(cluster_metadata_url(port), json.dumps(frameworkcontroller_config_data), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps(frameworkcontroller_config_data), REST_TIME_OUT) err_message = None if not response or not response.status_code == 200: if response is not None: @@ -218,16 +218,16 @@ def set_frameworkcontroller_config(experiment_config, port, config_file_name): with open(stderr_full_path, 'a+') as fout: fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))) return False, err_message - set_V1_common_config(experiment_config, port, config_file_name) - result, message = setNNIManagerIp(experiment_config, port, config_file_name) + set_V1_common_config(experiment_config, port, config_file_name, prefixUrl) + result, message = setNNIManagerIp(experiment_config, port, config_file_name, prefixUrl) if not result: return result, message #set trial_config - return set_trial_config(experiment_config, port, config_file_name), err_message + return set_trial_config(experiment_config, port, config_file_name, prefixUrl), err_message -def set_shared_storage(experiment_config, port, config_file_name): +def set_shared_storage(experiment_config, port, config_file_name, prefixUrl): if 'sharedStorage' in experiment_config: - response = rest_put(cluster_metadata_url(port), json.dumps({'shared_storage_config': experiment_config['sharedStorage']}), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps({'shared_storage_config': experiment_config['sharedStorage']}), REST_TIME_OUT) err_message = None if not response or not response.status_code == 200: if response is not None: @@ -238,7 +238,7 @@ def set_shared_storage(experiment_config, port, config_file_name): return False, err_message return True, None -def set_experiment_v1(experiment_config, mode, port, config_file_name): +def set_experiment_v1(experiment_config, mode, port, config_file_name, prefixUrl): '''Call startExperiment (rest POST /experiment) with yaml file content''' request_data = dict() request_data['authorName'] = experiment_config['authorName'] @@ -298,7 +298,7 @@ def set_experiment_v1(experiment_config, mode, port, config_file_name): elif experiment_config['trainingServicePlatform'] == 'adl': request_data['clusterMetaData'].append( {'key': 'trial_config', 'value': experiment_config['trial']}) - response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True) + response = rest_post(experiment_url(port, prefixUrl), json.dumps(request_data), REST_TIME_OUT, show_error=True) if check_response(response): return response else: @@ -309,9 +309,9 @@ def set_experiment_v1(experiment_config, mode, port, config_file_name): print_error('Setting experiment error, error message is {}'.format(response.text)) return None -def set_experiment_v2(experiment_config, mode, port, config_file_name): +def set_experiment_v2(experiment_config, mode, port, config_file_name, prefixUrl): '''Call startExperiment (rest POST /experiment) with yaml file content''' - response = rest_post(experiment_url(port), json.dumps(experiment_config), REST_TIME_OUT, show_error=True) + response = rest_post(experiment_url(port, prefixUrl), json.dumps(experiment_config), REST_TIME_OUT, show_error=True) if check_response(response): return response else: @@ -322,21 +322,21 @@ def set_experiment_v2(experiment_config, mode, port, config_file_name): print_error('Setting experiment error, error message is {}'.format(response.text)) return None -def set_platform_config(platform, experiment_config, port, config_file_name, rest_process): +def set_platform_config(platform, experiment_config, port, config_file_name, rest_process, prefixUrl): '''call set_cluster_metadata for specific platform''' print_normal('Setting {0} config...'.format(platform)) config_result, err_msg = None, None if platform == 'adl': - config_result, err_msg = set_adl_config(experiment_config, port, config_file_name) + config_result, err_msg = set_adl_config(experiment_config, port, config_file_name, prefixUrl) elif platform == 'kubeflow': - config_result, err_msg = set_kubeflow_config(experiment_config, port, config_file_name) + config_result, err_msg = set_kubeflow_config(experiment_config, port, config_file_name, prefixUrl) elif platform == 'frameworkcontroller': - config_result, err_msg = set_frameworkcontroller_config(experiment_config, port, config_file_name) + config_result, err_msg = set_frameworkcontroller_config(experiment_config, port, config_file_name, prefixUrl) else: raise Exception(ERROR_INFO % 'Unsupported platform!') exit(1) if config_result: - config_result, err_msg = set_shared_storage(experiment_config, port, config_file_name) + config_result, err_msg = set_shared_storage(experiment_config, port, config_file_name, prefixUrl) if config_result: print_normal('Successfully set {0} config!'.format(platform)) else: @@ -392,7 +392,8 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi # save experiment information Experiments().add_experiment(experiment_id, args.port, start_time, platform, - experiment_config.get('experimentName', 'N/A'), pid=rest_process.pid, logDir=log_dir) + experiment_config.get('experimentName', 'N/A') + , pid=rest_process.pid, logDir=log_dir, prefixUrl=args.url_prefix) # Deal with annotation if experiment_config.get('useAnnotation'): path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation') @@ -413,7 +414,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi experiment_config['searchSpace'] = '' # check rest server - running, _ = check_rest_server(args.port) + running, _ = check_rest_server(args.port, args.url_prefix) if running: print_normal('Successfully started Restful server!') else: @@ -427,17 +428,18 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi if config_version == 1 and mode != 'view': # set platform configuration set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\ - experiment_id, rest_process) + experiment_id, rest_process, args.url_prefix) # start a new experiment print_normal('Starting experiment...') # set debug configuration if mode != 'view' and experiment_config.get('debug') is None: experiment_config['debug'] = args.debug + print_normal(config_version) if config_version == 1: - response = set_experiment_v1(experiment_config, mode, args.port, experiment_id) + response = set_experiment_v1(experiment_config, mode, args.port, experiment_id, args.url_prefix) else: - response = set_experiment_v2(experiment_config, mode, args.port, experiment_id) + response = set_experiment_v2(experiment_config, mode, args.port, experiment_id, args.url_prefix) if response: if experiment_id is None: experiment_id = json.loads(response.text).get('experiment_id') @@ -537,6 +539,7 @@ def manage_stopped_experiment(args, mode): print_normal('{0} experiment {1}...'.format(mode, experiment_id)) experiment_config = Config(experiment_id, experiments_dict[args.id]['logDir']).get_config() experiments_config.update_experiment(args.id, 'port', args.port) + args.url_prefix = experiments_dict[args.id]['prefixUrl'] assert 'trainingService' in experiment_config or 'trainingServicePlatform' in experiment_config try: if 'trainingServicePlatform' in experiment_config: diff --git a/nni/tools/nnictl/nnictl_utils.py b/nni/tools/nnictl/nnictl_utils.py index 16942b74ef..8f61ca2c48 100644 --- a/nni/tools/nnictl/nnictl_utils.py +++ b/nni/tools/nnictl/nnictl_utils.py @@ -25,17 +25,17 @@ from .command_utils import check_output_command, kill_command from .ssh_utils import create_ssh_sftp_client, remove_remote_directory -def get_experiment_time(port): +def get_experiment_time(port, prefixUrl): '''get the startTime and endTime of an experiment''' - response = rest_get(experiment_url(port), REST_TIME_OUT) + response = rest_get(experiment_url(port, prefixUrl), REST_TIME_OUT) if response and check_response(response): content = json.loads(response.text) return content.get('startTime'), content.get('endTime') return None, None -def get_experiment_status(port): +def get_experiment_status(port, prefixUrl): '''get the status of an experiment''' - result, response = check_rest_server_quick(port) + result, response = check_rest_server_quick(port, prefixUrl) if result: return json.loads(response.text).get('status') return None @@ -202,7 +202,8 @@ def check_rest(args): experiments_config = Experiments() experiments_dict = experiments_config.get_all_experiments() rest_port = experiments_dict.get(get_config_filename(args)).get('port') - running, _ = check_rest_server_quick(rest_port) + prefix_url = experiments_dict.get(get_config_filename(args)).get('prefixUrl') + running, _ = check_rest_server_quick(rest_port, prefix_url) if running: print_normal('Restful server is running...') else: @@ -245,13 +246,14 @@ def final_metric_data_cmp(lhs, rhs): experiments_dict = experiments_config.get_all_experiments() experiment_id = get_config_filename(args) rest_port = experiments_dict.get(experiment_id).get('port') + prefix_url = experiments_dict.get(experiment_id).get('prefixUrl') rest_pid = experiments_dict.get(experiment_id).get('pid') if not detect_process(rest_pid): print_error('Experiment is not running...') return - running, response = check_rest_server_quick(rest_port) + running, response = check_rest_server_quick(rest_port, prefix_url) if running: - response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT) + response = rest_get(trial_jobs_url(rest_port, prefix_url), REST_TIME_OUT) if response and check_response(response): content = json.loads(response.text) if args.head: @@ -278,13 +280,14 @@ def trial_kill(args): experiments_dict = experiments_config.get_all_experiments() experiment_id = get_config_filename(args) rest_port = experiments_dict.get(experiment_id).get('port') + prefix_url = experiments_dict.get(experiment_id).get('prefixUrl') rest_pid = experiments_dict.get(experiment_id).get('pid') if not detect_process(rest_pid): print_error('Experiment is not running...') return - running, _ = check_rest_server_quick(rest_port) + running, _ = check_rest_server_quick(rest_port, prefix_url) if running: - response = rest_delete(trial_job_id_url(rest_port, args.trial_id), REST_TIME_OUT) + response = rest_delete(trial_job_id_url(rest_port, args.trial_id, prefix_url), REST_TIME_OUT) if response and check_response(response): print(response.text) return True @@ -311,13 +314,14 @@ def list_experiment(args): experiments_dict = experiments_config.get_all_experiments() experiment_id = get_config_filename(args) rest_port = experiments_dict.get(experiment_id).get('port') + prefix_url = experiments_dict.get(experiment_id).get('prefixUrl') rest_pid = experiments_dict.get(experiment_id).get('pid') if not detect_process(rest_pid): print_error('Experiment is not running...') return - running, _ = check_rest_server_quick(rest_port) + running, _ = check_rest_server_quick(rest_port, prefix_url) if running: - response = rest_get(experiment_url(rest_port), REST_TIME_OUT) + response = rest_get(experiment_url(rest_port, prefix_url), REST_TIME_OUT) if response and check_response(response): content = convert_time_stamp_to_date(json.loads(response.text)) print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':'))) @@ -333,7 +337,8 @@ def experiment_status(args): experiments_config = Experiments() experiments_dict = experiments_config.get_all_experiments() rest_port = experiments_dict.get(get_config_filename(args)).get('port') - result, response = check_rest_server_quick(rest_port) + prefix_url = experiments_dict.get(get_config_filename(args)).get('prefixUrl') + result, response = check_rest_server_quick(rest_port, prefix_url) if not result: print_normal('Restful server is not running...') else: @@ -399,14 +404,15 @@ def log_trial(args): experiments_dict = experiments_config.get_all_experiments() experiment_id = get_config_filename(args) rest_port = experiments_dict.get(experiment_id).get('port') + prefix_url = experiments_dict.get(experiment_id).get('prefixUrl') rest_pid = experiments_dict.get(experiment_id).get('pid') experiment_config = Config(experiment_id, experiments_dict.get(experiment_id).get('logDir')).get_config() if not detect_process(rest_pid): print_error('Experiment is not running...') return - running, response = check_rest_server_quick(rest_port) + running, response = check_rest_server_quick(rest_port, prefix_url) if running: - response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT) + response = rest_get(trial_jobs_url(rest_port, prefix_url), REST_TIME_OUT) if response and check_response(response): content = json.loads(response.text) for trial in content: @@ -661,9 +667,9 @@ def show_experiment_info(): experiments_dict[key].get('platform'), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['startTime'] / 1000)) if isinstance(experiments_dict[key]['startTime'], int) else experiments_dict[key]['startTime'], \ get_time_interval(experiments_dict[key]['startTime'], experiments_dict[key]['endTime']))) print(TRIAL_MONITOR_HEAD) - running, response = check_rest_server_quick(experiments_dict[key]['port']) + running, response = check_rest_server_quick(experiments_dict[key]['port'], experiments_dict[key]['prefixUrl']) if running: - response = rest_get(trial_jobs_url(experiments_dict[key]['port']), REST_TIME_OUT) + response = rest_get(trial_jobs_url(experiments_dict[key]['port'], experiments_dict[key]['prefixUrl']), REST_TIME_OUT) if response and check_response(response): content = json.loads(response.text) for index, value in enumerate(content): @@ -672,7 +678,7 @@ def show_experiment_info(): content[index].get('endTime'), content[index].get('status'))) print(TRIAL_MONITOR_TAIL) -def set_monitor(auto_exit, time_interval, port=None, pid=None): +def set_monitor(auto_exit, time_interval, port=None, pid=None, prefixUrl=None): '''set the experiment monitor engine''' while True: try: @@ -683,7 +689,7 @@ def set_monitor(auto_exit, time_interval, port=None, pid=None): update_experiment() show_experiment_info() if auto_exit: - status = get_experiment_status(port) + status = get_experiment_status(port, prefixUrl) if status in ['DONE', 'ERROR', 'STOPPED']: print_normal('Experiment status is {0}.'.format(status)) print_normal('Stopping experiment...') @@ -724,20 +730,21 @@ def groupby_trial_id(intermediate_results): experiments_dict = experiments_config.get_all_experiments() experiment_id = get_config_filename(args) rest_port = experiments_dict.get(experiment_id).get('port') + prefix_url = experiments_dict.get(experiment_id).get('prefixUrl') rest_pid = experiments_dict.get(experiment_id).get('pid') if not detect_process(rest_pid): print_error('Experiment is not running...') return - running, response = check_rest_server_quick(rest_port) + running, response = check_rest_server_quick(rest_port, prefix_url) if not running: print_error('Restful server is not running') return - response = rest_get(export_data_url(rest_port), 20) + response = rest_get(export_data_url(rest_port, prefix_url), 20) if response is not None and check_response(response): content = json.loads(response.text) if args.intermediate: - intermediate_results_response = rest_get(metric_data_url(rest_port), REST_TIME_OUT) + intermediate_results_response = rest_get(metric_data_url(rest_port, prefix_url), REST_TIME_OUT) if not intermediate_results_response or not check_response(intermediate_results_response): print_error('Error getting intermediate results.') return diff --git a/nni/tools/nnictl/rest_utils.py b/nni/tools/nnictl/rest_utils.py index e98c9a8392..b9bcf0f29c 100644 --- a/nni/tools/nnictl/rest_utils.py +++ b/nni/tools/nnictl/rest_utils.py @@ -49,11 +49,11 @@ def rest_delete(url, timeout, show_error=False): print_error(exception) return None -def check_rest_server(rest_port): +def check_rest_server(rest_port, prefixUrl): '''Check if restful server is ready''' retry_count = 20 for _ in range(retry_count): - response = rest_get(check_status_url(rest_port), REST_TIME_OUT) + response = rest_get(check_status_url(rest_port, prefixUrl), REST_TIME_OUT) if response: if response.status_code == 200: return True, response @@ -63,9 +63,9 @@ def check_rest_server(rest_port): time.sleep(1) return False, response -def check_rest_server_quick(rest_port): +def check_rest_server_quick(rest_port, prefixUrl): '''Check if restful server is ready, only check once''' - response = rest_get(check_status_url(rest_port), 5) + response = rest_get(check_status_url(rest_port, prefixUrl), 5) if response and response.status_code == 200: return True, response return False, None diff --git a/nni/tools/nnictl/updater.py b/nni/tools/nnictl/updater.py index e462562349..8e2dbc8152 100644 --- a/nni/tools/nnictl/updater.py +++ b/nni/tools/nnictl/updater.py @@ -62,13 +62,14 @@ def update_experiment_profile(args, key, value): experiments_config = Experiments() experiments_dict = experiments_config.get_all_experiments() rest_port = experiments_dict.get(get_config_filename(args)).get('port') - running, _ = check_rest_server_quick(rest_port) + prefix_url = experiments_dict.get(get_config_filename(args)).get('prefixUrl') + running, _ = check_rest_server_quick(rest_port, prefix_url) if running: - response = rest_get(experiment_url(rest_port), REST_TIME_OUT) + response = rest_get(experiment_url(rest_port, prefix_url), REST_TIME_OUT) if response and check_response(response): experiment_profile = json.loads(response.text) experiment_profile['params'][key] = value - response = rest_put(experiment_url(rest_port)+get_query_type(key), json.dumps(experiment_profile), REST_TIME_OUT) + response = rest_put(experiment_url(rest_port, prefix_url)+get_query_type(key), json.dumps(experiment_profile), REST_TIME_OUT) if response and check_response(response): return response else: @@ -121,11 +122,12 @@ def import_data(args): experiments_dict = Experiments().get_all_experiments() experiment_id = get_config_filename(args) rest_port = experiments_dict.get(experiment_id).get('port') + prefix_url = experiments_dict.get(get_config_filename(args)).get('prefixUrl') rest_pid = experiments_dict.get(experiment_id).get('pid') if not detect_process(rest_pid): print_error('Experiment is not running...') return - running, _ = check_rest_server_quick(rest_port) + running, _ = check_rest_server_quick(rest_port, prefix_url) if not running: print_error('Restful server is not running') return @@ -141,9 +143,10 @@ def import_data_to_restful_server(args, content): '''call restful server to import data to the experiment''' experiments_dict = Experiments().get_all_experiments() rest_port = experiments_dict.get(get_config_filename(args)).get('port') - running, _ = check_rest_server_quick(rest_port) + prefix_url = experiments_dict.get(get_config_filename(args)).get('prefixUrl') + running, _ = check_rest_server_quick(rest_port, prefix_url) if running: - response = rest_post(import_data_url(rest_port), content, REST_TIME_OUT) + response = rest_post(import_data_url(rest_port, prefix_url), content, REST_TIME_OUT) if response and check_response(response): return response else: diff --git a/nni/tools/nnictl/url_utils.py b/nni/tools/nnictl/url_utils.py index 091fcb9de3..aa532a20e7 100644 --- a/nni/tools/nnictl/url_utils.py +++ b/nni/tools/nnictl/url_utils.py @@ -28,54 +28,51 @@ def path_validation(path): assert re.match("^[A-Za-z0-9_-]*$", path), "prefix url is invalid." -def set_prefix_url(prefix): - '''set prefix url''' - path_validation(prefix) - global API_ROOT_URL - API_ROOT_URL = '{0}/{1}'.format(API_ROOT_URL, prefix) +def formatURLPath(path): + return '' if path is None else '/{0}'.format(path) -def metric_data_url(port): +def metric_data_url(port,prefix): '''get metric_data url''' - return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, METRIC_DATA_API) + return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), METRIC_DATA_API) -def check_status_url(port): +def check_status_url(port,prefix): '''get check_status url''' - return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CHECK_STATUS_API) + return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), CHECK_STATUS_API) -def cluster_metadata_url(port): +def cluster_metadata_url(port,prefix): '''get cluster_metadata_url''' - return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CLUSTER_METADATA_API) + return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), CLUSTER_METADATA_API) -def import_data_url(port): +def import_data_url(port,prefix): '''get import_data_url''' - return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, IMPORT_DATA_API) + return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), IMPORT_DATA_API) -def experiment_url(port): +def experiment_url(port,prefix): '''get experiment_url''' - return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, EXPERIMENT_API) + return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), EXPERIMENT_API) -def trial_jobs_url(port): +def trial_jobs_url(port,prefix): '''get trial_jobs url''' - return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, TRIAL_JOBS_API) + return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), TRIAL_JOBS_API) -def trial_job_id_url(port, job_id): +def trial_job_id_url(port, job_id,prefix): '''get trial_jobs with id url''' - return '{0}:{1}{2}{3}/{4}'.format(BASE_URL, port, API_ROOT_URL, TRIAL_JOBS_API, job_id) + return '{0}:{1}{2}{3}{4}/{5}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), TRIAL_JOBS_API, job_id) -def export_data_url(port): +def export_data_url(port,prefix): '''get export_data url''' - return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, EXPORT_DATA_API) + return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), EXPORT_DATA_API) -def tensorboard_url(port): +def tensorboard_url(port,prefix): '''get tensorboard url''' - return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, TENSORBOARD_API) + return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), TENSORBOARD_API) def get_local_urls(port): From 5a71ce916b9298d7d3cda77b3fee24d4474f05e8 Mon Sep 17 00:00:00 2001 From: Hao Ni Date: Fri, 14 May 2021 14:15:46 +0800 Subject: [PATCH 3/5] nnictl prompt descriptions for web_ui_urls --- nni/tools/nnictl/launcher.py | 7 +++---- nni/tools/nnictl/url_utils.py | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/nni/tools/nnictl/launcher.py b/nni/tools/nnictl/launcher.py index 9fd1c56e0f..433a3b1499 100644 --- a/nni/tools/nnictl/launcher.py +++ b/nni/tools/nnictl/launcher.py @@ -16,7 +16,7 @@ import nni_node # pylint: disable=import-error from .launcher_utils import validate_all_content from .rest_utils import rest_put, rest_post, check_rest_server, check_response -from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, path_validation +from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, path_validation, formatURLPath from .config_utils import Config, Experiments from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \ detect_port, get_user @@ -435,7 +435,6 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi # set debug configuration if mode != 'view' and experiment_config.get('debug') is None: experiment_config['debug'] = args.debug - print_normal(config_version) if config_version == 1: response = set_experiment_v1(experiment_config, mode, args.port, experiment_id, args.url_prefix) else: @@ -452,9 +451,9 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi raise Exception(ERROR_INFO % 'Restful server stopped!') exit(1) if experiment_config.get('nniManagerIp'): - web_ui_url_list = ['http://{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))] + web_ui_url_list = ['http://{0}:{1}{2}'.format(experiment_config['nniManagerIp'], str(args.port), formatURLPath(args.url_prefix))] else: - web_ui_url_list = get_local_urls(args.port) + web_ui_url_list = get_local_urls(args.port, args.url_prefix) Experiments().update_experiment(experiment_id, 'webuiUrl', web_ui_url_list) print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list))) diff --git a/nni/tools/nnictl/url_utils.py b/nni/tools/nnictl/url_utils.py index aa532a20e7..e744ad723a 100644 --- a/nni/tools/nnictl/url_utils.py +++ b/nni/tools/nnictl/url_utils.py @@ -75,11 +75,11 @@ def tensorboard_url(port,prefix): return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), TENSORBOARD_API) -def get_local_urls(port): +def get_local_urls(port,prefix): '''get urls of local machine''' url_list = [] for _, info in psutil.net_if_addrs().items(): for addr in info: if socket.AddressFamily.AF_INET == addr.family: - url_list.append('http://{}:{}'.format(addr.address, port)) + url_list.append('http://{0}:{1}{2}'.format(addr.address, port, formatURLPath(prefix))) return url_list From 18406310ef80848ecfdbb11831e6df5cbac94ec8 Mon Sep 17 00:00:00 2001 From: Hao Ni Date: Mon, 17 May 2021 14:27:51 +0800 Subject: [PATCH 4/5] override api_root_path --- nni/tools/nnictl/launcher.py | 79 ++++++++++--------- nni/tools/nnictl/nnictl_utils.py | 49 +++++------- nni/tools/nnictl/rest_utils.py | 8 +- nni/tools/nnictl/updater.py | 15 ++-- nni/tools/nnictl/url_utils.py | 44 +++++------ .../common/experimentStartupInfo.ts | 17 ++-- ts/nni_manager/rest_server/nniRestServer.ts | 5 +- 7 files changed, 110 insertions(+), 107 deletions(-) diff --git a/nni/tools/nnictl/launcher.py b/nni/tools/nnictl/launcher.py index 433a3b1499..ed2f298c6b 100644 --- a/nni/tools/nnictl/launcher.py +++ b/nni/tools/nnictl/launcher.py @@ -9,6 +9,7 @@ import random import time import tempfile +import re from subprocess import Popen, check_call, CalledProcessError, PIPE, STDOUT from nni.experiment.config import ExperimentConfig, convert from nni.tools.annotation import expand_annotations, generate_search_space @@ -16,7 +17,7 @@ import nni_node # pylint: disable=import-error from .launcher_utils import validate_all_content from .rest_utils import rest_put, rest_post, check_rest_server, check_response -from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, path_validation, formatURLPath +from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, setPrefixUrl, formatURLPath from .config_utils import Config, Experiments from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \ detect_port, get_user @@ -82,7 +83,8 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log if foreground: cmds += ['--foreground', 'true'] if url_prefix: - path_validation(url_prefix) + _validate_prefix_path(url_prefix) + setPrefixUrl(url_prefix) cmds += ['--url_prefix', url_prefix] stdout_full_path, stderr_full_path = get_log_path(experiment_id) @@ -106,11 +108,11 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file) return process, int(start_time * 1000) -def set_trial_config(experiment_config, port, config_file_name, prefixUrl): +def set_trial_config(experiment_config, port, config_file_name): '''set trial configuration''' request_data = dict() request_data['trial_config'] = experiment_config['trial'] - response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps(request_data), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT) if check_response(response): return True else: @@ -121,12 +123,12 @@ def set_trial_config(experiment_config, port, config_file_name, prefixUrl): fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':'))) return False -def set_adl_config(experiment_config, port, config_file_name, prefixUrl): +def set_adl_config(experiment_config, port, config_file_name): '''set adl configuration''' adl_config_data = dict() # hack for supporting v2 config, need refactor adl_config_data['adl_config'] = {} - response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps(adl_config_data), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port), json.dumps(adl_config_data), REST_TIME_OUT) err_message = None if not response or not response.status_code == 200: if response is not None: @@ -136,11 +138,11 @@ def set_adl_config(experiment_config, port, config_file_name, prefixUrl): fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))) return False, err_message set_V1_common_config(experiment_config, port, config_file_name) - result, message = setNNIManagerIp(experiment_config, port, config_file_name, prefixUrl) + result, message = setNNIManagerIp(experiment_config, port, config_file_name) if not result: return result, message #set trial_config - return set_trial_config(experiment_config, port, config_file_name, prefixUrl), None + return set_trial_config(experiment_config, port, config_file_name), None def validate_response(response, config_file_name): err_message = None @@ -154,7 +156,7 @@ def validate_response(response, config_file_name): exit(1) # hack to fix v1 version_check and log_collection bug, need refactor -def set_V1_common_config(experiment_config, port, config_file_name, prefixUrl): +def set_V1_common_config(experiment_config, port, config_file_name): version_check = True #debug mode should disable version check if experiment_config.get('debug') is not None: @@ -162,19 +164,19 @@ def set_V1_common_config(experiment_config, port, config_file_name, prefixUrl): #validate version check if experiment_config.get('versionCheck') is not None: version_check = experiment_config.get('versionCheck') - response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps({'version_check': version_check}), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port), json.dumps({'version_check': version_check}), REST_TIME_OUT) validate_response(response, config_file_name) if experiment_config.get('logCollection'): - response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps({'log_collection': experiment_config.get('logCollection')}), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port), json.dumps({'log_collection': experiment_config.get('logCollection')}), REST_TIME_OUT) validate_response(response, config_file_name) -def setNNIManagerIp(experiment_config, port, config_file_name, prefixUrl): +def setNNIManagerIp(experiment_config, port, config_file_name): '''set nniManagerIp''' if experiment_config.get('nniManagerIp') is None: return True, None ip_config_dict = dict() ip_config_dict['nni_manager_ip'] = {'nniManagerIp': experiment_config['nniManagerIp']} - response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps(ip_config_dict), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port), json.dumps(ip_config_dict), REST_TIME_OUT) err_message = None if not response or not response.status_code == 200: if response is not None: @@ -185,11 +187,11 @@ def setNNIManagerIp(experiment_config, port, config_file_name, prefixUrl): return False, err_message return True, None -def set_kubeflow_config(experiment_config, port, config_file_name, prefixUrl): +def set_kubeflow_config(experiment_config, port, config_file_name): '''set kubeflow configuration''' kubeflow_config_data = dict() kubeflow_config_data['kubeflow_config'] = experiment_config['kubeflowConfig'] - response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps(kubeflow_config_data), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port), json.dumps(kubeflow_config_data), REST_TIME_OUT) err_message = None if not response or not response.status_code == 200: if response is not None: @@ -198,18 +200,18 @@ def set_kubeflow_config(experiment_config, port, config_file_name, prefixUrl): with open(stderr_full_path, 'a+') as fout: fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))) return False, err_message - set_V1_common_config(experiment_config, port, config_file_name, prefixUrl) - result, message = setNNIManagerIp(experiment_config, port, config_file_name, prefixUrl) + set_V1_common_config(experiment_config, port, config_file_name) + result, message = setNNIManagerIp(experiment_config, port, config_file_name) if not result: return result, message #set trial_config - return set_trial_config(experiment_config, port, config_file_name, prefixUrl), err_message + return set_trial_config(experiment_config, port, config_file_name), err_message -def set_frameworkcontroller_config(experiment_config, port, config_file_name, prefixUrl): +def set_frameworkcontroller_config(experiment_config, port, config_file_name): '''set kubeflow configuration''' frameworkcontroller_config_data = dict() frameworkcontroller_config_data['frameworkcontroller_config'] = experiment_config['frameworkcontrollerConfig'] - response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps(frameworkcontroller_config_data), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port), json.dumps(frameworkcontroller_config_data), REST_TIME_OUT) err_message = None if not response or not response.status_code == 200: if response is not None: @@ -218,16 +220,16 @@ def set_frameworkcontroller_config(experiment_config, port, config_file_name, pr with open(stderr_full_path, 'a+') as fout: fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))) return False, err_message - set_V1_common_config(experiment_config, port, config_file_name, prefixUrl) - result, message = setNNIManagerIp(experiment_config, port, config_file_name, prefixUrl) + set_V1_common_config(experiment_config, port, config_file_name) + result, message = setNNIManagerIp(experiment_config, port, config_file_name) if not result: return result, message #set trial_config - return set_trial_config(experiment_config, port, config_file_name, prefixUrl), err_message + return set_trial_config(experiment_config, port, config_file_name), err_message -def set_shared_storage(experiment_config, port, config_file_name, prefixUrl): +def set_shared_storage(experiment_config, port, config_file_name): if 'sharedStorage' in experiment_config: - response = rest_put(cluster_metadata_url(port, prefixUrl), json.dumps({'shared_storage_config': experiment_config['sharedStorage']}), REST_TIME_OUT) + response = rest_put(cluster_metadata_url(port), json.dumps({'shared_storage_config': experiment_config['sharedStorage']}), REST_TIME_OUT) err_message = None if not response or not response.status_code == 200: if response is not None: @@ -238,7 +240,7 @@ def set_shared_storage(experiment_config, port, config_file_name, prefixUrl): return False, err_message return True, None -def set_experiment_v1(experiment_config, mode, port, config_file_name, prefixUrl): +def set_experiment_v1(experiment_config, mode, port, config_file_name): '''Call startExperiment (rest POST /experiment) with yaml file content''' request_data = dict() request_data['authorName'] = experiment_config['authorName'] @@ -298,7 +300,7 @@ def set_experiment_v1(experiment_config, mode, port, config_file_name, prefixUrl elif experiment_config['trainingServicePlatform'] == 'adl': request_data['clusterMetaData'].append( {'key': 'trial_config', 'value': experiment_config['trial']}) - response = rest_post(experiment_url(port, prefixUrl), json.dumps(request_data), REST_TIME_OUT, show_error=True) + response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True) if check_response(response): return response else: @@ -309,9 +311,9 @@ def set_experiment_v1(experiment_config, mode, port, config_file_name, prefixUrl print_error('Setting experiment error, error message is {}'.format(response.text)) return None -def set_experiment_v2(experiment_config, mode, port, config_file_name, prefixUrl): +def set_experiment_v2(experiment_config, mode, port, config_file_name): '''Call startExperiment (rest POST /experiment) with yaml file content''' - response = rest_post(experiment_url(port, prefixUrl), json.dumps(experiment_config), REST_TIME_OUT, show_error=True) + response = rest_post(experiment_url(port), json.dumps(experiment_config), REST_TIME_OUT, show_error=True) if check_response(response): return response else: @@ -322,21 +324,21 @@ def set_experiment_v2(experiment_config, mode, port, config_file_name, prefixUrl print_error('Setting experiment error, error message is {}'.format(response.text)) return None -def set_platform_config(platform, experiment_config, port, config_file_name, rest_process, prefixUrl): +def set_platform_config(platform, experiment_config, port, config_file_name, rest_process): '''call set_cluster_metadata for specific platform''' print_normal('Setting {0} config...'.format(platform)) config_result, err_msg = None, None if platform == 'adl': - config_result, err_msg = set_adl_config(experiment_config, port, config_file_name, prefixUrl) + config_result, err_msg = set_adl_config(experiment_config, port, config_file_name) elif platform == 'kubeflow': - config_result, err_msg = set_kubeflow_config(experiment_config, port, config_file_name, prefixUrl) + config_result, err_msg = set_kubeflow_config(experiment_config, port, config_file_name) elif platform == 'frameworkcontroller': - config_result, err_msg = set_frameworkcontroller_config(experiment_config, port, config_file_name, prefixUrl) + config_result, err_msg = set_frameworkcontroller_config(experiment_config, port, config_file_name) else: raise Exception(ERROR_INFO % 'Unsupported platform!') exit(1) if config_result: - config_result, err_msg = set_shared_storage(experiment_config, port, config_file_name, prefixUrl) + config_result, err_msg = set_shared_storage(experiment_config, port, config_file_name) if config_result: print_normal('Successfully set {0} config!'.format(platform)) else: @@ -414,7 +416,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi experiment_config['searchSpace'] = '' # check rest server - running, _ = check_rest_server(args.port, args.url_prefix) + running, _ = check_rest_server(args.port) if running: print_normal('Successfully started Restful server!') else: @@ -436,9 +438,9 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi if mode != 'view' and experiment_config.get('debug') is None: experiment_config['debug'] = args.debug if config_version == 1: - response = set_experiment_v1(experiment_config, mode, args.port, experiment_id, args.url_prefix) + response = set_experiment_v1(experiment_config, mode, args.port, experiment_id) else: - response = set_experiment_v2(experiment_config, mode, args.port, experiment_id, args.url_prefix) + response = set_experiment_v2(experiment_config, mode, args.port, experiment_id) if response: if experiment_id is None: experiment_id = json.loads(response.text).get('experiment_id') @@ -481,6 +483,9 @@ def _validate_v2(config, path): except Exception as e: print_error(f'Config V2 validation failed: {repr(e)}') +def _validate_prefix_path(path): + assert re.match("^[A-Za-z0-9_-]*$", path), "prefix url is invalid." + def create_experiment(args): '''start a new experiment''' experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8)) diff --git a/nni/tools/nnictl/nnictl_utils.py b/nni/tools/nnictl/nnictl_utils.py index 8f61ca2c48..16942b74ef 100644 --- a/nni/tools/nnictl/nnictl_utils.py +++ b/nni/tools/nnictl/nnictl_utils.py @@ -25,17 +25,17 @@ from .command_utils import check_output_command, kill_command from .ssh_utils import create_ssh_sftp_client, remove_remote_directory -def get_experiment_time(port, prefixUrl): +def get_experiment_time(port): '''get the startTime and endTime of an experiment''' - response = rest_get(experiment_url(port, prefixUrl), REST_TIME_OUT) + response = rest_get(experiment_url(port), REST_TIME_OUT) if response and check_response(response): content = json.loads(response.text) return content.get('startTime'), content.get('endTime') return None, None -def get_experiment_status(port, prefixUrl): +def get_experiment_status(port): '''get the status of an experiment''' - result, response = check_rest_server_quick(port, prefixUrl) + result, response = check_rest_server_quick(port) if result: return json.loads(response.text).get('status') return None @@ -202,8 +202,7 @@ def check_rest(args): experiments_config = Experiments() experiments_dict = experiments_config.get_all_experiments() rest_port = experiments_dict.get(get_config_filename(args)).get('port') - prefix_url = experiments_dict.get(get_config_filename(args)).get('prefixUrl') - running, _ = check_rest_server_quick(rest_port, prefix_url) + running, _ = check_rest_server_quick(rest_port) if running: print_normal('Restful server is running...') else: @@ -246,14 +245,13 @@ def final_metric_data_cmp(lhs, rhs): experiments_dict = experiments_config.get_all_experiments() experiment_id = get_config_filename(args) rest_port = experiments_dict.get(experiment_id).get('port') - prefix_url = experiments_dict.get(experiment_id).get('prefixUrl') rest_pid = experiments_dict.get(experiment_id).get('pid') if not detect_process(rest_pid): print_error('Experiment is not running...') return - running, response = check_rest_server_quick(rest_port, prefix_url) + running, response = check_rest_server_quick(rest_port) if running: - response = rest_get(trial_jobs_url(rest_port, prefix_url), REST_TIME_OUT) + response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT) if response and check_response(response): content = json.loads(response.text) if args.head: @@ -280,14 +278,13 @@ def trial_kill(args): experiments_dict = experiments_config.get_all_experiments() experiment_id = get_config_filename(args) rest_port = experiments_dict.get(experiment_id).get('port') - prefix_url = experiments_dict.get(experiment_id).get('prefixUrl') rest_pid = experiments_dict.get(experiment_id).get('pid') if not detect_process(rest_pid): print_error('Experiment is not running...') return - running, _ = check_rest_server_quick(rest_port, prefix_url) + running, _ = check_rest_server_quick(rest_port) if running: - response = rest_delete(trial_job_id_url(rest_port, args.trial_id, prefix_url), REST_TIME_OUT) + response = rest_delete(trial_job_id_url(rest_port, args.trial_id), REST_TIME_OUT) if response and check_response(response): print(response.text) return True @@ -314,14 +311,13 @@ def list_experiment(args): experiments_dict = experiments_config.get_all_experiments() experiment_id = get_config_filename(args) rest_port = experiments_dict.get(experiment_id).get('port') - prefix_url = experiments_dict.get(experiment_id).get('prefixUrl') rest_pid = experiments_dict.get(experiment_id).get('pid') if not detect_process(rest_pid): print_error('Experiment is not running...') return - running, _ = check_rest_server_quick(rest_port, prefix_url) + running, _ = check_rest_server_quick(rest_port) if running: - response = rest_get(experiment_url(rest_port, prefix_url), REST_TIME_OUT) + response = rest_get(experiment_url(rest_port), REST_TIME_OUT) if response and check_response(response): content = convert_time_stamp_to_date(json.loads(response.text)) print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':'))) @@ -337,8 +333,7 @@ def experiment_status(args): experiments_config = Experiments() experiments_dict = experiments_config.get_all_experiments() rest_port = experiments_dict.get(get_config_filename(args)).get('port') - prefix_url = experiments_dict.get(get_config_filename(args)).get('prefixUrl') - result, response = check_rest_server_quick(rest_port, prefix_url) + result, response = check_rest_server_quick(rest_port) if not result: print_normal('Restful server is not running...') else: @@ -404,15 +399,14 @@ def log_trial(args): experiments_dict = experiments_config.get_all_experiments() experiment_id = get_config_filename(args) rest_port = experiments_dict.get(experiment_id).get('port') - prefix_url = experiments_dict.get(experiment_id).get('prefixUrl') rest_pid = experiments_dict.get(experiment_id).get('pid') experiment_config = Config(experiment_id, experiments_dict.get(experiment_id).get('logDir')).get_config() if not detect_process(rest_pid): print_error('Experiment is not running...') return - running, response = check_rest_server_quick(rest_port, prefix_url) + running, response = check_rest_server_quick(rest_port) if running: - response = rest_get(trial_jobs_url(rest_port, prefix_url), REST_TIME_OUT) + response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT) if response and check_response(response): content = json.loads(response.text) for trial in content: @@ -667,9 +661,9 @@ def show_experiment_info(): experiments_dict[key].get('platform'), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['startTime'] / 1000)) if isinstance(experiments_dict[key]['startTime'], int) else experiments_dict[key]['startTime'], \ get_time_interval(experiments_dict[key]['startTime'], experiments_dict[key]['endTime']))) print(TRIAL_MONITOR_HEAD) - running, response = check_rest_server_quick(experiments_dict[key]['port'], experiments_dict[key]['prefixUrl']) + running, response = check_rest_server_quick(experiments_dict[key]['port']) if running: - response = rest_get(trial_jobs_url(experiments_dict[key]['port'], experiments_dict[key]['prefixUrl']), REST_TIME_OUT) + response = rest_get(trial_jobs_url(experiments_dict[key]['port']), REST_TIME_OUT) if response and check_response(response): content = json.loads(response.text) for index, value in enumerate(content): @@ -678,7 +672,7 @@ def show_experiment_info(): content[index].get('endTime'), content[index].get('status'))) print(TRIAL_MONITOR_TAIL) -def set_monitor(auto_exit, time_interval, port=None, pid=None, prefixUrl=None): +def set_monitor(auto_exit, time_interval, port=None, pid=None): '''set the experiment monitor engine''' while True: try: @@ -689,7 +683,7 @@ def set_monitor(auto_exit, time_interval, port=None, pid=None, prefixUrl=None): update_experiment() show_experiment_info() if auto_exit: - status = get_experiment_status(port, prefixUrl) + status = get_experiment_status(port) if status in ['DONE', 'ERROR', 'STOPPED']: print_normal('Experiment status is {0}.'.format(status)) print_normal('Stopping experiment...') @@ -730,21 +724,20 @@ def groupby_trial_id(intermediate_results): experiments_dict = experiments_config.get_all_experiments() experiment_id = get_config_filename(args) rest_port = experiments_dict.get(experiment_id).get('port') - prefix_url = experiments_dict.get(experiment_id).get('prefixUrl') rest_pid = experiments_dict.get(experiment_id).get('pid') if not detect_process(rest_pid): print_error('Experiment is not running...') return - running, response = check_rest_server_quick(rest_port, prefix_url) + running, response = check_rest_server_quick(rest_port) if not running: print_error('Restful server is not running') return - response = rest_get(export_data_url(rest_port, prefix_url), 20) + response = rest_get(export_data_url(rest_port), 20) if response is not None and check_response(response): content = json.loads(response.text) if args.intermediate: - intermediate_results_response = rest_get(metric_data_url(rest_port, prefix_url), REST_TIME_OUT) + intermediate_results_response = rest_get(metric_data_url(rest_port), REST_TIME_OUT) if not intermediate_results_response or not check_response(intermediate_results_response): print_error('Error getting intermediate results.') return diff --git a/nni/tools/nnictl/rest_utils.py b/nni/tools/nnictl/rest_utils.py index b9bcf0f29c..e98c9a8392 100644 --- a/nni/tools/nnictl/rest_utils.py +++ b/nni/tools/nnictl/rest_utils.py @@ -49,11 +49,11 @@ def rest_delete(url, timeout, show_error=False): print_error(exception) return None -def check_rest_server(rest_port, prefixUrl): +def check_rest_server(rest_port): '''Check if restful server is ready''' retry_count = 20 for _ in range(retry_count): - response = rest_get(check_status_url(rest_port, prefixUrl), REST_TIME_OUT) + response = rest_get(check_status_url(rest_port), REST_TIME_OUT) if response: if response.status_code == 200: return True, response @@ -63,9 +63,9 @@ def check_rest_server(rest_port, prefixUrl): time.sleep(1) return False, response -def check_rest_server_quick(rest_port, prefixUrl): +def check_rest_server_quick(rest_port): '''Check if restful server is ready, only check once''' - response = rest_get(check_status_url(rest_port, prefixUrl), 5) + response = rest_get(check_status_url(rest_port), 5) if response and response.status_code == 200: return True, response return False, None diff --git a/nni/tools/nnictl/updater.py b/nni/tools/nnictl/updater.py index 8e2dbc8152..e462562349 100644 --- a/nni/tools/nnictl/updater.py +++ b/nni/tools/nnictl/updater.py @@ -62,14 +62,13 @@ def update_experiment_profile(args, key, value): experiments_config = Experiments() experiments_dict = experiments_config.get_all_experiments() rest_port = experiments_dict.get(get_config_filename(args)).get('port') - prefix_url = experiments_dict.get(get_config_filename(args)).get('prefixUrl') - running, _ = check_rest_server_quick(rest_port, prefix_url) + running, _ = check_rest_server_quick(rest_port) if running: - response = rest_get(experiment_url(rest_port, prefix_url), REST_TIME_OUT) + response = rest_get(experiment_url(rest_port), REST_TIME_OUT) if response and check_response(response): experiment_profile = json.loads(response.text) experiment_profile['params'][key] = value - response = rest_put(experiment_url(rest_port, prefix_url)+get_query_type(key), json.dumps(experiment_profile), REST_TIME_OUT) + response = rest_put(experiment_url(rest_port)+get_query_type(key), json.dumps(experiment_profile), REST_TIME_OUT) if response and check_response(response): return response else: @@ -122,12 +121,11 @@ def import_data(args): experiments_dict = Experiments().get_all_experiments() experiment_id = get_config_filename(args) rest_port = experiments_dict.get(experiment_id).get('port') - prefix_url = experiments_dict.get(get_config_filename(args)).get('prefixUrl') rest_pid = experiments_dict.get(experiment_id).get('pid') if not detect_process(rest_pid): print_error('Experiment is not running...') return - running, _ = check_rest_server_quick(rest_port, prefix_url) + running, _ = check_rest_server_quick(rest_port) if not running: print_error('Restful server is not running') return @@ -143,10 +141,9 @@ def import_data_to_restful_server(args, content): '''call restful server to import data to the experiment''' experiments_dict = Experiments().get_all_experiments() rest_port = experiments_dict.get(get_config_filename(args)).get('port') - prefix_url = experiments_dict.get(get_config_filename(args)).get('prefixUrl') - running, _ = check_rest_server_quick(rest_port, prefix_url) + running, _ = check_rest_server_quick(rest_port) if running: - response = rest_post(import_data_url(rest_port, prefix_url), content, REST_TIME_OUT) + response = rest_post(import_data_url(rest_port), content, REST_TIME_OUT) if response and check_response(response): return response else: diff --git a/nni/tools/nnictl/url_utils.py b/nni/tools/nnictl/url_utils.py index e744ad723a..71af16de68 100644 --- a/nni/tools/nnictl/url_utils.py +++ b/nni/tools/nnictl/url_utils.py @@ -3,7 +3,6 @@ import socket import psutil -import re BASE_URL = 'http://localhost' @@ -25,54 +24,55 @@ METRIC_DATA_API = '/metric-data' -def path_validation(path): - assert re.match("^[A-Za-z0-9_-]*$", path), "prefix url is invalid." - def formatURLPath(path): return '' if path is None else '/{0}'.format(path) -def metric_data_url(port,prefix): +def setPrefixUrl(prefix_path): + global API_ROOT_URL + API_ROOT_URL = formatURLPath(prefix_path) + +def metric_data_url(port): '''get metric_data url''' - return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), METRIC_DATA_API) + return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, METRIC_DATA_API) -def check_status_url(port,prefix): +def check_status_url(port): '''get check_status url''' - return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), CHECK_STATUS_API) + return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CHECK_STATUS_API) -def cluster_metadata_url(port,prefix): +def cluster_metadata_url(port): '''get cluster_metadata_url''' - return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), CLUSTER_METADATA_API) + return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CLUSTER_METADATA_API) -def import_data_url(port,prefix): +def import_data_url(port): '''get import_data_url''' - return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), IMPORT_DATA_API) + return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, IMPORT_DATA_API) -def experiment_url(port,prefix): +def experiment_url(port): '''get experiment_url''' - return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), EXPERIMENT_API) + return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, EXPERIMENT_API) -def trial_jobs_url(port,prefix): +def trial_jobs_url(port): '''get trial_jobs url''' - return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), TRIAL_JOBS_API) + return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, TRIAL_JOBS_API) -def trial_job_id_url(port, job_id,prefix): +def trial_job_id_url(port, job_id): '''get trial_jobs with id url''' - return '{0}:{1}{2}{3}{4}/{5}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), TRIAL_JOBS_API, job_id) + return '{0}:{1}{2}{3}/{4}'.format(BASE_URL, port, API_ROOT_URL, TRIAL_JOBS_API, job_id) -def export_data_url(port,prefix): +def export_data_url(port): '''get export_data url''' - return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), EXPORT_DATA_API) + return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, EXPORT_DATA_API) -def tensorboard_url(port,prefix): +def tensorboard_url(port): '''get tensorboard url''' - return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), TENSORBOARD_API) + return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, TENSORBOARD_API) def get_local_urls(port,prefix): diff --git a/ts/nni_manager/common/experimentStartupInfo.ts b/ts/nni_manager/common/experimentStartupInfo.ts index c56afb9409..4dfb88bb19 100644 --- a/ts/nni_manager/common/experimentStartupInfo.ts +++ b/ts/nni_manager/common/experimentStartupInfo.ts @@ -10,6 +10,8 @@ import * as component from '../common/component'; @component.Singleton class ExperimentStartupInfo { + private readonly API_ROOT_URL: string = '/api/v1/nni'; + private experimentId: string = ''; private newExperiment: boolean = true; private basePort: number = -1; @@ -100,9 +102,14 @@ class ExperimentStartupInfo { return this.dispatcherPipe; } - public getUrlPrefix(): string { + public getAPIRootUrl(): string { assert(this.initialized); - return this.urlprefix; + if(this.urlprefix==''){ + return this.API_ROOT_URL; + } + else{ + return `/${this.urlprefix}`; + } } } @@ -140,11 +147,11 @@ function getDispatcherPipe(): string | null { return component.get(ExperimentStartupInfo).getDispatcherPipe(); } -function getUrlPrefix(): string { - return component.get(ExperimentStartupInfo).getUrlPrefix(); +function getAPIRootUrl(): string { + return component.get(ExperimentStartupInfo).getAPIRootUrl(); } export { ExperimentStartupInfo, getBasePort, getExperimentId, isNewExperiment, getPlatform, getExperimentStartupInfo, - setExperimentStartupInfo, isReadonly, getDispatcherPipe, getUrlPrefix + setExperimentStartupInfo, isReadonly, getDispatcherPipe, getAPIRootUrl }; diff --git a/ts/nni_manager/rest_server/nniRestServer.ts b/ts/nni_manager/rest_server/nniRestServer.ts index b14c081bd5..cc8c016c94 100644 --- a/ts/nni_manager/rest_server/nniRestServer.ts +++ b/ts/nni_manager/rest_server/nniRestServer.ts @@ -10,7 +10,7 @@ import * as component from '../common/component'; import { RestServer } from '../common/restServer' import { getLogDir } from '../common/utils'; import { createRestHandler } from './restHandler'; -import { getUrlPrefix } from '../common/experimentStartupInfo'; +import { getAPIRootUrl } from '../common/experimentStartupInfo'; /** * NNI Main rest server, provides rest API to support @@ -20,14 +20,15 @@ import { getUrlPrefix } from '../common/experimentStartupInfo'; */ @component.Singleton export class NNIRestServer extends RestServer { - private readonly API_ROOT_URL: string = `/api/v1/nni/${getUrlPrefix()}`; private readonly LOGS_ROOT_URL: string = '/logs'; + protected API_ROOT_URL: string = '/api/v1/nni'; /** * constructor to provide NNIRestServer's own rest property, e.g. port */ constructor() { super(); + this.API_ROOT_URL = getAPIRootUrl(); } /** From 904f8568497a00ca81bb97e0477d800cc8b69886 Mon Sep 17 00:00:00 2001 From: Hao Ni Date: Mon, 17 May 2021 14:42:36 +0800 Subject: [PATCH 5/5] fix parameter issue --- nni/tools/nnictl/launcher.py | 2 +- ts/nni_manager/common/experimentStartupInfo.ts | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/nni/tools/nnictl/launcher.py b/nni/tools/nnictl/launcher.py index ed2f298c6b..e4d6a0966a 100644 --- a/nni/tools/nnictl/launcher.py +++ b/nni/tools/nnictl/launcher.py @@ -430,7 +430,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi if config_version == 1 and mode != 'view': # set platform configuration set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\ - experiment_id, rest_process, args.url_prefix) + experiment_id, rest_process) # start a new experiment print_normal('Starting experiment...') diff --git a/ts/nni_manager/common/experimentStartupInfo.ts b/ts/nni_manager/common/experimentStartupInfo.ts index 4dfb88bb19..18d412869d 100644 --- a/ts/nni_manager/common/experimentStartupInfo.ts +++ b/ts/nni_manager/common/experimentStartupInfo.ts @@ -104,12 +104,7 @@ class ExperimentStartupInfo { public getAPIRootUrl(): string { assert(this.initialized); - if(this.urlprefix==''){ - return this.API_ROOT_URL; - } - else{ - return `/${this.urlprefix}`; - } + return this.urlprefix==''?this.API_ROOT_URL:`/${this.urlprefix}`; } }