Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix config v2 bugs #3540

Merged
merged 4 commits into from
Apr 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nni/tools/nnictl/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def add_experiment(self, expId, port, startTime, platform, experiment_name, endT
self.experiments[expId]['tag'] = tag
self.experiments[expId]['pid'] = pid
self.experiments[expId]['webuiUrl'] = webuiUrl
self.experiments[expId]['logDir'] = logDir
self.experiments[expId]['logDir'] = str(logDir)
self.write_file()

def update_experiment(self, expId, key, value):
Expand Down
45 changes: 30 additions & 15 deletions nni/tools/nnictl/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,21 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
kill_command(rest_process.pid)
print_normal('Stopping experiment...')

def _validate_v1(config, path):
try:
validate_all_content(config, path)
except Exception as e:
print_error(f'Config V1 validation failed: {repr(e)}')
exit(1)

def _validate_v2(config, path):
base_path = Path(path).parent
try:
conf = ExperimentConfig(_base_path=base_path, **config)
return conf.json()
except Exception as e:
print_error(f'Config V2 validation failed: {repr(e)}')

def create_experiment(args):
'''start a new experiment'''
experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8))
Expand All @@ -420,23 +435,23 @@ def create_experiment(args):
exit(1)
config_yml = get_yml_content(config_path)

try:
config = ExperimentConfig(_base_path=Path(config_path).parent, **config_yml)
config_v2 = config.json()
except Exception as error_v2:
print_warning('Validation with V2 schema failed. Trying to convert from V1 format...')
try:
validate_all_content(config_yml, config_path)
except Exception as error_v1:
print_error(f'Convert from v1 format failed: {repr(error_v1)}')
print_error(f'Config in v2 format validation failed: {repr(error_v2)}')
exit(1)
from nni.experiment.config import convert
config_v2 = convert.to_v2(config_yml).json()
if 'trainingServicePlatform' in config_yml:
_validate_v1(config_yml, config_path)
platform = config_yml['trainingServicePlatform']
if platform in k8s_training_services:
schema = 1
config_v1 = config_yml
else:
schema = 2
from nni.experiment.config import convert
config_v2 = convert.to_v2(config_yml).json()
else:
config_v2 = _validate_v2(config_yml, config_path)
schema = 2

try:
if getattr(config_v2['trainingService'], 'platform', None) in k8s_training_services:
launch_experiment(args, config_yml, 'new', experiment_id, 1)
if schema == 1:
launch_experiment(args, config_v1, 'new', experiment_id, 1)
else:
launch_experiment(args, config_v2, 'new', experiment_id, 2)
except Exception as exception:
Expand Down
35 changes: 0 additions & 35 deletions nni/tools/nnictl/nnictl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import traceback
from datetime import datetime, timezone
from subprocess import Popen
from pyhdfs import HdfsClient
from nni.tools.annotation import expand_annotations
import nni_node # pylint: disable=import-error
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
Expand Down Expand Up @@ -501,30 +500,6 @@ def remote_clean(machine_list, experiment_id=None):
print_normal('removing folder {0}'.format(host + ':' + str(port) + remote_dir))
remove_remote_directory(sftp, remote_dir)

def hdfs_clean(host, user_name, output_dir, experiment_id=None):
'''clean up hdfs data'''
hdfs_client = HdfsClient(hosts='{0}:80'.format(host), user_name=user_name, webhdfs_path='/webhdfs/api/v1', timeout=5)
if experiment_id:
full_path = '/' + '/'.join([user_name, 'nni', 'experiments', experiment_id])
else:
full_path = '/' + '/'.join([user_name, 'nni', 'experiments'])
print_normal('removing folder {0} in hdfs'.format(full_path))
hdfs_client.delete(full_path, recursive=True)
if output_dir:
pattern = re.compile('hdfs://(?P<host>([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(?P<baseDir>/.*)?')
match_result = pattern.match(output_dir)
if match_result:
output_host = match_result.group('host')
output_dir = match_result.group('baseDir')
#check if the host is valid
if output_host != host:
print_warning('The host in {0} is not consistent with {1}'.format(output_dir, host))
else:
if experiment_id:
output_dir = output_dir + '/' + experiment_id
print_normal('removing folder {0} in hdfs'.format(output_dir))
hdfs_client.delete(output_dir, recursive=True)

def experiment_clean(args):
'''clean up the experiment data'''
experiment_id_list = []
Expand Down Expand Up @@ -556,11 +531,6 @@ def experiment_clean(args):
if platform == 'remote':
machine_list = experiment_config.get('machineList')
remote_clean(machine_list, experiment_id)
elif platform == 'pai':
host = experiment_config.get('paiConfig').get('host')
user_name = experiment_config.get('paiConfig').get('userName')
output_dir = experiment_config.get('trial').get('outputDir')
hdfs_clean(host, user_name, output_dir, experiment_id)
elif platform != 'local':
# TODO: support all platforms
print_warning('platform {0} clean up not supported yet.'.format(platform))
Expand Down Expand Up @@ -632,11 +602,6 @@ def platform_clean(args):
if platform == 'remote':
machine_list = config_content.get('machineList')
remote_clean(machine_list)
elif platform == 'pai':
host = config_content.get('paiConfig').get('host')
user_name = config_content.get('paiConfig').get('userName')
output_dir = config_content.get('trial').get('outputDir')
hdfs_clean(host, user_name, output_dir)
print_normal('Done.')

def experiment_list(args):
Expand Down
11 changes: 7 additions & 4 deletions ts/nni_manager/core/nnimanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,15 @@ class NNIManager implements Manager {
return this.dataStore.getTrialJob(trialJobId);
}

public async setClusterMetadata(_key: string, _value: string): Promise<void> {
throw new Error('Calling removed API setClusterMetadata');
public async setClusterMetadata(key: string, value: string): Promise<void> {
while (this.trainingService === undefined) {
await delay(1000);
}
this.trainingService.setClusterMetadata(key, value);
}

public getClusterMetadata(_key: string): Promise<string> {
throw new Error('Calling removed API getClusterMetadata');
public getClusterMetadata(key: string): Promise<string> {
return this.trainingService.getClusterMetadata(key);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need check if this.trainingService === undefined ?

}

public async getTrialJobStatistics(): Promise<TrialJobStatistics[]> {
Expand Down
4 changes: 4 additions & 0 deletions ts/nni_manager/training_service/reusable/environment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ export class EnvironmentInformation {

export abstract class EnvironmentService {

public async init(): Promise<void> {
return;
}

public abstract get hasStorageService(): boolean;
public abstract refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void>;
public abstract stopEnvironment(environment: EnvironmentInformation): Promise<void>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export class RemoteEnvironmentService extends EnvironmentService {
private readonly environmentExecutorManagerMap: Map<string, ExecutorManager>;
private readonly remoteMachineMetaOccupiedMap: Map<RemoteMachineConfig, boolean>;
private readonly log: Logger;
private sshConnectionPromises: any[];
private sshConnectionPromises: Promise<void[]>;
private experimentRootDir: string;
private remoteExperimentRootDir: string = "";
private experimentId: string;
Expand All @@ -39,7 +39,6 @@ export class RemoteEnvironmentService extends EnvironmentService {
this.environmentExecutorManagerMap = new Map<string, ExecutorManager>();
this.machineExecutorManagerMap = new Map<RemoteMachineConfig, ExecutorManager>();
this.remoteMachineMetaOccupiedMap = new Map<RemoteMachineConfig, boolean>();
this.sshConnectionPromises = [];
this.experimentRootDir = getExperimentRootDir();
this.experimentId = getExperimentId();
this.log = getLogger();
Expand All @@ -50,9 +49,18 @@ export class RemoteEnvironmentService extends EnvironmentService {
throw new Error(`codeDir ${this.config.trialCodeDirectory} is not a directory`);
}

this.sshConnectionPromises = this.config.machineList.map(
this.sshConnectionPromises = Promise.all(this.config.machineList.map(
machine => this.initRemoteMachineOnConnected(machine)
);
));
}

public async init(): Promise<void> {
await this.sshConnectionPromises;
this.log.info('ssh connection initialized!');
Array.from(this.machineExecutorManagerMap.keys()).forEach(rmMeta => {
// initialize remoteMachineMetaOccupiedMap, false means not occupied
this.remoteMachineMetaOccupiedMap.set(rmMeta, false);
});
}

public get prefetchedEnvironmentCount(): number {
Expand Down Expand Up @@ -204,16 +212,6 @@ export class RemoteEnvironmentService extends EnvironmentService {
}

public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
if (this.sshConnectionPromises.length > 0) {
await Promise.all(this.sshConnectionPromises);
this.log.info('ssh connection initialized!');
// set sshConnectionPromises to [] to avoid log information duplicated
this.sshConnectionPromises = [];
Array.from(this.machineExecutorManagerMap.keys()).forEach(rmMeta => {
// initialize remoteMachineMetaOccupiedMap, false means not occupied
this.remoteMachineMetaOccupiedMap.set(rmMeta, false);
});
}
const remoteEnvironment: RemoteMachineEnvironmentInformation = environment as RemoteMachineEnvironmentInformation;
remoteEnvironment.status = 'WAITING';
// schedule machine for environment, generate command
Expand Down
5 changes: 3 additions & 2 deletions ts/nni_manager/training_service/reusable/trialDispatcher.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ class TrialDispatcher implements TrainingService {
this.environmentServiceList.push(env);
}

// FIXME: max?
this.environmentMaintenceLoopInterval = Math.max(
...this.environmentServiceList.map((env) => env.environmentMaintenceLoopInterval)
);
Expand Down Expand Up @@ -211,6 +210,7 @@ class TrialDispatcher implements TrainingService {
}

public async run(): Promise<void> {
await Promise.all(this.environmentServiceList.map(env => env.init()));
for(const environmentService of this.environmentServiceList) {

const runnerSettings: RunnerSettings = new RunnerSettings();
Expand Down Expand Up @@ -497,9 +497,10 @@ class TrialDispatcher implements TrainingService {
liveEnvironmentsCount++;
if (environment.status === "RUNNING" && environment.isRunnerReady) {
// if environment is not reusable and used, stop and not count as idle;
const reuseMode = Array.isArray(this.config.trainingService) || (this.config.trainingService as any).reuseMode;
if (
0 === environment.runningTrialCount &&
!(this.config as any).reuseMode &&
!reuseMode &&
environment.assignedTrialCount > 0
) {
if (environment.environmentService === undefined) {
Expand Down
10 changes: 2 additions & 8 deletions ts/webui/src/components/overview/count/EditExperimentParam.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,7 @@ export const EditExperimentParam = (): any => {
}
if (isMaxDuration) {
const maxDura = JSON.parse(editInputVal);
if (unit === 'm') {
newProfile.params[field] = maxDura * 60;
} else if (unit === 'h') {
newProfile.params[field] = maxDura * 3600;
} else {
newProfile.params[field] = maxDura * 24 * 60 * 60;
}
newProfile.params[field] = `${maxDura}${unit}`;
} else {
newProfile.params[field] = parseInt(editInputVal, 10);
}
Expand Down Expand Up @@ -162,7 +156,7 @@ export const EditExperimentParam = (): any => {
<EditExpeParamContext.Consumer>
{(value): React.ReactNode => {
let editClassName = '';
if (value.field === 'maxExecDuration') {
if (value.field === 'maxExperimentDuration') {
editClassName = isShowPencil ? 'noEditDuration' : 'editDuration';
}
return (
Expand Down
2 changes: 1 addition & 1 deletion ts/webui/src/components/overview/count/ExpDuration.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export const ExpDuration = (): any => (
<EditExpeParamContext.Provider
value={{
editType: CONTROLTYPE[0],
field: 'maxExecDuration',
field: 'maxExperimentDuration',
title: 'Max duration',
maxExecDuration: maxExecDurationStr,
maxTrialNum: EXPERIMENT.maxTrialNumber,
Expand Down
2 changes: 1 addition & 1 deletion ts/webui/src/components/overview/count/TrialCount.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ export const TrialCount = (): any => {
<EditExpeParamContext.Provider
value={{
title: MAX_TRIAL_NUMBERS,
field: 'maxTrialNum',
field: 'maxTrialNumber',
editType: CONTROLTYPE[1],
maxExecDuration: '',
maxTrialNum: EXPERIMENT.maxTrialNumber,
Expand Down
4 changes: 2 additions & 2 deletions ts/webui/src/components/slideNav/TrialConfigPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ class TrialConfigPanel extends React.Component<LogDrawerProps, LogDrawerState> {
<AppContext.Consumer>
{(value): React.ReactNode => {
const unit = value.maxDurationUnit;
profile.params.maxExecDuration = `${convertTimeAsUnit(
profile.params.maxExperimentDuration = `${convertTimeAsUnit(
unit,
profile.params.maxExecDuration
profile.params.maxExperimentDuration
)}${unit}`;
const showProfile = JSON.stringify(profile, filter, 2);
return (
Expand Down
5 changes: 4 additions & 1 deletion ts/webui/src/static/experimentConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ export interface ExperimentConfig {

const timeUnits = { d: 24 * 3600, h: 3600, m: 60, s: 1 };

export function toSeconds(time: string): number {
export function toSeconds(time: string | number): number {
if (typeof time === 'number') {
return time;
}
for (const [unit, factor] of Object.entries(timeUnits)) {
if (time.endsWith(unit)) {
const digits = time.slice(0, -1);
Expand Down