diff --git a/nni/experiment/config/aml.py b/nni/experiment/config/aml.py index 9837ab38e8..2fd92e7e76 100644 --- a/nni/experiment/config/aml.py +++ b/nni/experiment/config/aml.py @@ -15,6 +15,7 @@ class AmlConfig(TrainingServiceConfig): workspace_name: str compute_target: str docker_image: str = 'msranni/nni:latest' + max_trial_number_per_gpu: int = 1 _validation_rules = { 'platform': lambda value: (value == 'aml', 'cannot be modified') diff --git a/nni/experiment/config/convert.py b/nni/experiment/config/convert.py index 6f36151e5a..621be6d1cd 100644 --- a/nni/experiment/config/convert.py +++ b/nni/experiment/config/convert.py @@ -134,7 +134,7 @@ def to_v2(v1) -> ExperimentConfig: _move_field(aml_config, ts, 'resourceGroup', 'resource_group') _move_field(aml_config, ts, 'workspaceName', 'workspace_name') _move_field(aml_config, ts, 'computeTarget', 'compute_target') - _deprecate(aml_config, v2, 'maxTrialNumPerGpu') + _move_field(aml_config, ts, 'maxTrialNumPerGpu', 'max_trial_number_per_gpu') _deprecate(aml_config, v2, 'useActiveGpu') assert not aml_config, aml_config diff --git a/ts/nni_manager/common/experimentConfig.ts b/ts/nni_manager/common/experimentConfig.ts index 6f3ff588eb..fff40c547c 100644 --- a/ts/nni_manager/common/experimentConfig.ts +++ b/ts/nni_manager/common/experimentConfig.ts @@ -65,6 +65,7 @@ export interface AmlConfig extends TrainingServiceConfig { workspaceName: string; computeTarget: string; dockerImage: string; + maxTrialNumberPerGpu: number; } /* Kubeflow */ diff --git a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts index 6ca2e67c56..ca19c77560 100644 --- a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts @@ -9,15 +9,16 @@ import * as component from '../../../common/component'; import { getExperimentId } from '../../../common/experimentStartupInfo'; import { getLogger, Logger } from '../../../common/log'; import { getExperimentRootDir } from '../../../common/utils'; -import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey'; +import { ExperimentConfig, AmlConfig, flattenConfig } from '../../../common/experimentConfig'; import { validateCodeDir } from '../../common/util'; import { AMLClient } from '../aml/amlClient'; -import { AMLClusterConfig, AMLEnvironmentInformation, AMLTrialConfig } from '../aml/amlConfig'; +import { AMLEnvironmentInformation } from '../aml/amlConfig'; import { EnvironmentInformation, EnvironmentService } from '../environment'; import { EventEmitter } from "events"; import { AMLCommandChannel } from '../channels/amlCommandChannel'; import { SharedStorageService } from '../sharedStorage' +interface FlattenAmlConfig extends ExperimentConfig, AmlConfig { } /** * Collector AML jobs info from AML cluster, and update aml job status locally @@ -26,15 +27,16 @@ import { SharedStorageService } from '../sharedStorage' export class AMLEnvironmentService extends EnvironmentService { private readonly log: Logger = getLogger(); - public amlClusterConfig: AMLClusterConfig | undefined; - public amlTrialConfig: AMLTrialConfig | undefined; private experimentId: string; private experimentRootDir: string; + private config: FlattenAmlConfig; - constructor() { + constructor(config: ExperimentConfig) { super(); this.experimentId = getExperimentId(); this.experimentRootDir = getExperimentRootDir(); + this.config = flattenConfig(config, 'aml'); + validateCodeDir(this.config.trialCodeDirectory); } public get hasStorageService(): boolean { @@ -53,27 +55,6 @@ export class AMLEnvironmentService extends EnvironmentService { return 'aml'; } - public async config(key: string, value: string): Promise { - switch (key) { - case TrialConfigMetadataKey.AML_CLUSTER_CONFIG: - this.amlClusterConfig = JSON.parse(value); - break; - - case TrialConfigMetadataKey.TRIAL_CONFIG: { - if (this.amlClusterConfig === undefined) { - this.log.error('aml cluster config is not initialized'); - break; - } - this.amlTrialConfig = JSON.parse(value); - // Validate to make sure codeDir doesn't have too many files - await validateCodeDir(this.amlTrialConfig.codeDir); - break; - } - default: - this.log.debug(`AML not proccessed metadata key: '${key}', value: '${value}'`); - } - } - public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise { environments.forEach(async (environment) => { const amlClient = (environment as AMLEnvironmentInformation).amlClient; @@ -107,12 +88,6 @@ export class AMLEnvironmentService extends EnvironmentService { } public async startEnvironment(environment: EnvironmentInformation): Promise { - if (this.amlClusterConfig === undefined) { - throw new Error('AML Cluster config is not initialized'); - } - if (this.amlTrialConfig === undefined) { - throw new Error('AML trial config is not initialized'); - } const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation; const environmentLocalTempFolder = path.join(this.experimentRootDir, "environment-temp"); if (!fs.existsSync(environmentLocalTempFolder)) { @@ -126,22 +101,24 @@ export class AMLEnvironmentService extends EnvironmentService { amlEnvironment.command = `mv envs outputs/envs && cd outputs && ${amlEnvironment.command}`; } amlEnvironment.command = `import os\nos.system('${amlEnvironment.command}')`; - amlEnvironment.useActiveGpu = this.amlClusterConfig.useActiveGpu; - amlEnvironment.maxTrialNumberPerGpu = this.amlClusterConfig.maxTrialNumPerGpu; + amlEnvironment.useActiveGpu = !!this.config.deprecated.useActiveGpu; + amlEnvironment.maxTrialNumberPerGpu = this.config.maxTrialNumberPerGpu; await fs.promises.writeFile(path.join(environmentLocalTempFolder, 'nni_script.py'), amlEnvironment.command, { encoding: 'utf8' }); const amlClient = new AMLClient( - this.amlClusterConfig.subscriptionId, - this.amlClusterConfig.resourceGroup, - this.amlClusterConfig.workspaceName, + this.config.subscriptionId, + this.config.resourceGroup, + this.config.workspaceName, this.experimentId, - this.amlClusterConfig.computeTarget, - this.amlTrialConfig.image, + this.config.computeTarget, + this.config.dockerImage, 'nni_script.py', environmentLocalTempFolder ); amlEnvironment.id = await amlClient.submit(); + this.log.debug('aml: before getTrackingUrl'); amlEnvironment.trackingUrl = await amlClient.getTrackingUrl(); + this.log.debug('aml: after getTrackingUrl'); amlEnvironment.amlClient = amlClient; } diff --git a/ts/nni_manager/training_service/reusable/environments/environmentServiceFactory.ts b/ts/nni_manager/training_service/reusable/environments/environmentServiceFactory.ts index 1e3124fac7..1a32407bd4 100644 --- a/ts/nni_manager/training_service/reusable/environments/environmentServiceFactory.ts +++ b/ts/nni_manager/training_service/reusable/environments/environmentServiceFactory.ts @@ -13,7 +13,7 @@ export class EnvironmentServiceFactory { case 'remote': return new RemoteEnvironmentService(config); case 'aml': - return new AMLEnvironmentService(); + return new AMLEnvironmentService(config); case 'openpai': return new OpenPaiEnvironmentService(config); default: diff --git a/ts/nni_manager/training_service/reusable/trialDispatcher.ts b/ts/nni_manager/training_service/reusable/trialDispatcher.ts index d3d828a921..40e98360de 100644 --- a/ts/nni_manager/training_service/reusable/trialDispatcher.ts +++ b/ts/nni_manager/training_service/reusable/trialDispatcher.ts @@ -500,7 +500,7 @@ class TrialDispatcher implements TrainingService { const reuseMode = Array.isArray(this.config.trainingService) || (this.config.trainingService as any).reuseMode; if ( 0 === environment.runningTrialCount && - !reuseMode && + reuseMode === false && environment.assignedTrialCount > 0 ) { if (environment.environmentService === undefined) {