diff --git a/nni/tools/nnictl/config_utils.py b/nni/tools/nnictl/config_utils.py index 5cef7e2449..154b1d8ccb 100644 --- a/nni/tools/nnictl/config_utils.py +++ b/nni/tools/nnictl/config_utils.py @@ -108,7 +108,7 @@ def __init__(self, home_dir=NNI_HOME_DIR): self.experiments = self.read_file() def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED', - tag=[], pid=None, webuiUrl=[], logDir=''): + tag=[], pid=None, webuiUrl=[], logDir='', prefixUrl=None): '''set {key:value} pairs to self.experiment''' with self.lock: self.experiments = self.read_file() @@ -124,6 +124,7 @@ def add_experiment(self, expId, port, startTime, platform, experiment_name, endT self.experiments[expId]['pid'] = pid self.experiments[expId]['webuiUrl'] = webuiUrl self.experiments[expId]['logDir'] = str(logDir) + self.experiments[expId]['prefixUrl'] = prefixUrl self.write_file() def update_experiment(self, expId, key, value): diff --git a/nni/tools/nnictl/launcher.py b/nni/tools/nnictl/launcher.py index 7b2d20f9b8..e4d6a0966a 100644 --- a/nni/tools/nnictl/launcher.py +++ b/nni/tools/nnictl/launcher.py @@ -9,6 +9,7 @@ import random import time import tempfile +import re from subprocess import Popen, check_call, CalledProcessError, PIPE, STDOUT from nni.experiment.config import ExperimentConfig, convert from nni.tools.annotation import expand_annotations, generate_search_space @@ -16,7 +17,7 @@ import nni_node # pylint: disable=import-error from .launcher_utils import validate_all_content from .rest_utils import rest_put, rest_post, check_rest_server, check_response -from .url_utils import cluster_metadata_url, experiment_url, get_local_urls +from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, setPrefixUrl, formatURLPath from .config_utils import Config, Experiments from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \ detect_port, get_user @@ -43,7 +44,7 @@ def print_log_content(config_file_name): print_normal(' Stderr:') print(check_output_command(stderr_full_path)) -def start_rest_server(port, platform, mode, experiment_id, foreground=False, log_dir=None, log_level=None): +def start_rest_server(port, platform, mode, experiment_id, foreground=False, log_dir=None, log_level=None, url_prefix=None): '''Run nni manager process''' if detect_port(port): print_error('Port %s is used by another process, please reset the port!\n' \ @@ -81,6 +82,11 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log cmds += ['--log_level', log_level] if foreground: cmds += ['--foreground', 'true'] + if url_prefix: + _validate_prefix_path(url_prefix) + setPrefixUrl(url_prefix) + cmds += ['--url_prefix', url_prefix] + stdout_full_path, stderr_full_path = get_log_path(experiment_id) with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file: start_time = time.time() @@ -384,11 +390,12 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi platform = experiment_config['trainingService']['platform'] rest_process, start_time = start_rest_server(args.port, platform, \ - mode, experiment_id, foreground, log_dir, log_level) + mode, experiment_id, foreground, log_dir, log_level, args.url_prefix) # save experiment information Experiments().add_experiment(experiment_id, args.port, start_time, platform, - experiment_config.get('experimentName', 'N/A'), pid=rest_process.pid, logDir=log_dir) + experiment_config.get('experimentName', 'N/A') + , pid=rest_process.pid, logDir=log_dir, prefixUrl=args.url_prefix) # Deal with annotation if experiment_config.get('useAnnotation'): path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation') @@ -446,9 +453,9 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi raise Exception(ERROR_INFO % 'Restful server stopped!') exit(1) if experiment_config.get('nniManagerIp'): - web_ui_url_list = ['http://{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))] + web_ui_url_list = ['http://{0}:{1}{2}'.format(experiment_config['nniManagerIp'], str(args.port), formatURLPath(args.url_prefix))] else: - web_ui_url_list = get_local_urls(args.port) + web_ui_url_list = get_local_urls(args.port, args.url_prefix) Experiments().update_experiment(experiment_id, 'webuiUrl', web_ui_url_list) print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list))) @@ -476,6 +483,9 @@ def _validate_v2(config, path): except Exception as e: print_error(f'Config V2 validation failed: {repr(e)}') +def _validate_prefix_path(path): + assert re.match("^[A-Za-z0-9_-]*$", path), "prefix url is invalid." + def create_experiment(args): '''start a new experiment''' experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8)) @@ -533,6 +543,7 @@ def manage_stopped_experiment(args, mode): print_normal('{0} experiment {1}...'.format(mode, experiment_id)) experiment_config = Config(experiment_id, experiments_dict[args.id]['logDir']).get_config() experiments_config.update_experiment(args.id, 'port', args.port) + args.url_prefix = experiments_dict[args.id]['prefixUrl'] assert 'trainingService' in experiment_config or 'trainingServicePlatform' in experiment_config try: if 'trainingServicePlatform' in experiment_config: diff --git a/nni/tools/nnictl/nnictl.py b/nni/tools/nnictl/nnictl.py index fd8697337e..d4b55d506a 100644 --- a/nni/tools/nnictl/nnictl.py +++ b/nni/tools/nnictl/nnictl.py @@ -54,6 +54,7 @@ def parse_args(): parser_start.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file') parser_start.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', type=int, help='the port of restful server') parser_start.add_argument('--debug', '-d', action='store_true', help=' set debug mode') + parser_start.add_argument('--url_prefix', '-u', dest='url_prefix', help=' set prefix url') parser_start.add_argument('--foreground', '-f', action='store_true', help=' set foreground mode, print log content to terminal') parser_start.set_defaults(func=create_experiment) diff --git a/nni/tools/nnictl/url_utils.py b/nni/tools/nnictl/url_utils.py index 59a28837a6..71af16de68 100644 --- a/nni/tools/nnictl/url_utils.py +++ b/nni/tools/nnictl/url_utils.py @@ -24,6 +24,13 @@ METRIC_DATA_API = '/metric-data' +def formatURLPath(path): + return '' if path is None else '/{0}'.format(path) + +def setPrefixUrl(prefix_path): + global API_ROOT_URL + API_ROOT_URL = formatURLPath(prefix_path) + def metric_data_url(port): '''get metric_data url''' return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, METRIC_DATA_API) @@ -60,7 +67,7 @@ def trial_job_id_url(port, job_id): def export_data_url(port): '''get export_data url''' - return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, EXPORT_DATA_API) + return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, EXPORT_DATA_API) def tensorboard_url(port): @@ -68,11 +75,11 @@ def tensorboard_url(port): return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, TENSORBOARD_API) -def get_local_urls(port): +def get_local_urls(port,prefix): '''get urls of local machine''' url_list = [] for _, info in psutil.net_if_addrs().items(): for addr in info: if socket.AddressFamily.AF_INET == addr.family: - url_list.append('http://{}:{}'.format(addr.address, port)) + url_list.append('http://{0}:{1}{2}'.format(addr.address, port, formatURLPath(prefix))) return url_list diff --git a/ts/nni_manager/common/experimentStartupInfo.ts b/ts/nni_manager/common/experimentStartupInfo.ts index 5316abd26e..18d412869d 100644 --- a/ts/nni_manager/common/experimentStartupInfo.ts +++ b/ts/nni_manager/common/experimentStartupInfo.ts @@ -10,6 +10,8 @@ import * as component from '../common/component'; @component.Singleton class ExperimentStartupInfo { + private readonly API_ROOT_URL: string = '/api/v1/nni'; + private experimentId: string = ''; private newExperiment: boolean = true; private basePort: number = -1; @@ -19,8 +21,9 @@ class ExperimentStartupInfo { private readonly: boolean = false; private dispatcherPipe: string | null = null; private platform: string = ''; + private urlprefix: string = ''; - public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean, dispatcherPipe?: string): void { + public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean, dispatcherPipe?: string, urlprefix?: string): void { assert(!this.initialized); assert(experimentId.trim().length > 0); this.newExperiment = newExperiment; @@ -46,6 +49,10 @@ class ExperimentStartupInfo { if (dispatcherPipe != undefined && dispatcherPipe.length > 0) { this.dispatcherPipe = dispatcherPipe; } + + if(urlprefix != undefined && urlprefix.length > 0){ + this.urlprefix = urlprefix; + } } public getExperimentId(): string { @@ -94,6 +101,11 @@ class ExperimentStartupInfo { assert(this.initialized); return this.dispatcherPipe; } + + public getAPIRootUrl(): string { + assert(this.initialized); + return this.urlprefix==''?this.API_ROOT_URL:`/${this.urlprefix}`; + } } function getExperimentId(): string { @@ -117,9 +129,9 @@ function getExperimentStartupInfo(): ExperimentStartupInfo { } function setExperimentStartupInfo( - newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean, dispatcherPipe?: string): void { + newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean, dispatcherPipe?: string, urlprefix?: string): void { component.get(ExperimentStartupInfo) - .setStartupInfo(newExperiment, experimentId, basePort, platform, logDir, logLevel, readonly, dispatcherPipe); + .setStartupInfo(newExperiment, experimentId, basePort, platform, logDir, logLevel, readonly, dispatcherPipe, urlprefix); } function isReadonly(): boolean { @@ -130,7 +142,11 @@ function getDispatcherPipe(): string | null { return component.get(ExperimentStartupInfo).getDispatcherPipe(); } +function getAPIRootUrl(): string { + return component.get(ExperimentStartupInfo).getAPIRootUrl(); +} + export { ExperimentStartupInfo, getBasePort, getExperimentId, isNewExperiment, getPlatform, getExperimentStartupInfo, - setExperimentStartupInfo, isReadonly, getDispatcherPipe + setExperimentStartupInfo, isReadonly, getDispatcherPipe, getAPIRootUrl }; diff --git a/ts/nni_manager/main.ts b/ts/nni_manager/main.ts index 3b6c9bae90..c104c3a190 100644 --- a/ts/nni_manager/main.ts +++ b/ts/nni_manager/main.ts @@ -25,9 +25,9 @@ import { NNIRestServer } from './rest_server/nniRestServer'; function initStartupInfo( startExpMode: string, experimentId: string, basePort: number, platform: string, - logDirectory: string, experimentLogLevel: string, readonly: boolean, dispatcherPipe: string): void { + logDirectory: string, experimentLogLevel: string, readonly: boolean, dispatcherPipe: string, urlprefix: string): void { const createNew: boolean = (startExpMode === ExperimentStartUpMode.NEW); - setExperimentStartupInfo(createNew, experimentId, basePort, platform, logDirectory, experimentLogLevel, readonly, dispatcherPipe); + setExperimentStartupInfo(createNew, experimentId, basePort, platform, logDirectory, experimentLogLevel, readonly, dispatcherPipe, urlprefix); } async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise { @@ -122,7 +122,9 @@ const readonly = readonlyArg.toLowerCase() == 'true' ? true : false; const dispatcherPipe: string = parseArg(['--dispatcher_pipe']); -initStartupInfo(startMode, experimentId, port, mode, logDir, logLevel, readonly, dispatcherPipe); +const urlPrefix: string = parseArg(['--url_prefix']); + +initStartupInfo(startMode, experimentId, port, mode, logDir, logLevel, readonly, dispatcherPipe, urlPrefix); mkDirP(getLogDir()) .then(async () => { diff --git a/ts/nni_manager/rest_server/nniRestServer.ts b/ts/nni_manager/rest_server/nniRestServer.ts index d9e185773b..cc8c016c94 100644 --- a/ts/nni_manager/rest_server/nniRestServer.ts +++ b/ts/nni_manager/rest_server/nniRestServer.ts @@ -10,6 +10,7 @@ import * as component from '../common/component'; import { RestServer } from '../common/restServer' import { getLogDir } from '../common/utils'; import { createRestHandler } from './restHandler'; +import { getAPIRootUrl } from '../common/experimentStartupInfo'; /** * NNI Main rest server, provides rest API to support @@ -19,14 +20,15 @@ import { createRestHandler } from './restHandler'; */ @component.Singleton export class NNIRestServer extends RestServer { - private readonly API_ROOT_URL: string = '/api/v1/nni'; private readonly LOGS_ROOT_URL: string = '/logs'; + protected API_ROOT_URL: string = '/api/v1/nni'; /** * constructor to provide NNIRestServer's own rest property, e.g. port */ constructor() { super(); + this.API_ROOT_URL = getAPIRootUrl(); } /**