diff --git a/ts/nni_manager/config/aml/amlUtil.py b/ts/nni_manager/config/aml/amlUtil.py index ca2a5e51a0..a5b2a6bc2d 100644 --- a/ts/nni_manager/config/aml/amlUtil.py +++ b/ts/nni_manager/config/aml/amlUtil.py @@ -44,6 +44,16 @@ print('tracking_url:' + run.get_portal_url()) elif line == 'stop': run.cancel() + loop_count = 0 + status = run.get_status() + # wait until the run is canceled + while status != 'Canceled': + if loop_count > 5: + print('stop_result:failed') + exit(0) + loop_count += 1 + time.sleep(500) + print('stop_result:success') exit(0) elif line == 'receive': print('receive:' + json.dumps(run.get_metrics())) diff --git a/ts/nni_manager/training_service/reusable/aml/amlClient.ts b/ts/nni_manager/training_service/reusable/aml/amlClient.ts index c1e10e7954..a93eb767ad 100644 --- a/ts/nni_manager/training_service/reusable/aml/amlClient.ts +++ b/ts/nni_manager/training_service/reusable/aml/amlClient.ts @@ -60,11 +60,21 @@ export class AMLClient { return deferred.promise; } - public stop(): void { + public stop(): Promise { if (this.pythonShellClient === undefined) { throw Error('python shell client not initialized!'); } + const deferred: Deferred = new Deferred(); this.pythonShellClient.send('stop'); + this.pythonShellClient.on('message', (result: any) => { + const stopResult = this.parseContent('stop_result', result); + if (stopResult === 'success') { + deferred.resolve(true); + } else if (stopResult === 'failed') { + deferred.resolve(false); + } + }); + return deferred.promise; } public getTrackingUrl(): Promise { diff --git a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts index 89468543e6..5c1755948d 100644 --- a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts @@ -127,6 +127,11 @@ export class AMLEnvironmentService extends EnvironmentService { if (!amlClient) { throw new Error('AML client not initialized!'); } - amlClient.stop(); + const result = await amlClient.stop(); + if (result) { + this.log.info(`Stop aml run ${environment.id} success!`); + } else { + this.log.info(`Stop aml run ${environment.id} failed!`); + } } } diff --git a/ts/nni_manager/training_service/reusable/trialDispatcher.ts b/ts/nni_manager/training_service/reusable/trialDispatcher.ts index 6dbbe2eb97..c29674e943 100644 --- a/ts/nni_manager/training_service/reusable/trialDispatcher.ts +++ b/ts/nni_manager/training_service/reusable/trialDispatcher.ts @@ -299,6 +299,16 @@ class TrialDispatcher implements TrainingService { public async setClusterMetadata(_key: string, _value: string): Promise { return; } public async getClusterMetadata(_key: string): Promise { return ""; } + public async stopEnvironment(environment: EnvironmentInformation): Promise { + if (environment.environmentService === undefined) { + throw new Error(`${environment.id} do not have environmentService!`); + } + this.log.info(`stopping environment ${environment.id}...`); + await environment.environmentService.stopEnvironment(environment); + this.log.info(`stopped environment ${environment.id}.`); + return; + } + public async cleanUp(): Promise { if (this.commandEmitter === undefined) { throw new Error(`TrialDispatcher: commandEmitter shouldn't be undefined in cleanUp.`); @@ -306,16 +316,12 @@ class TrialDispatcher implements TrainingService { this.stopping = true; this.shouldUpdateTrials = true; const environments = [...this.environments.values()]; - + + const stopEnvironmentPromise: Promise[] = []; for (let index = 0; index < environments.length; index++) { - const environment = environments[index]; - this.log.info(`stopping environment ${environment.id}...`); - if (environment.environmentService === undefined) { - throw new Error(`${environment.id} do not have environmentService!`); - } - await environment.environmentService.stopEnvironment(environment); - this.log.info(`stopped environment ${environment.id}.`); + stopEnvironmentPromise.push(this.stopEnvironment(environments[index])); } + await Promise.all(stopEnvironmentPromise); this.commandEmitter.off("command", this.handleCommand); for (const commandChannel of this.commandChannelSet) { await commandChannel.stop(); @@ -650,6 +656,10 @@ class TrialDispatcher implements TrainingService { } private async requestEnvironment(environmentService: EnvironmentService): Promise { + if (this.stopping) { + this.log.info(`Experiment is stopping, stop creating new environment`); + return; + } const envId = uniqueString(5); const envName = `nni_exp_${this.experimentId}_env_${envId}`; const environment = environmentService.createEnvironmentInformation(envId, envName);