Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
AML config v2 (#3552)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuzhe-lz authored Apr 21, 2021
1 parent 7fd0776 commit dc54f4a
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 42 deletions.
1 change: 1 addition & 0 deletions nni/experiment/config/aml.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class AmlConfig(TrainingServiceConfig):
workspace_name: str
compute_target: str
docker_image: str = 'msranni/nni:latest'
max_trial_number_per_gpu: int = 1

_validation_rules = {
'platform': lambda value: (value == 'aml', 'cannot be modified')
Expand Down
2 changes: 1 addition & 1 deletion nni/experiment/config/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def to_v2(v1) -> ExperimentConfig:
_move_field(aml_config, ts, 'resourceGroup', 'resource_group')
_move_field(aml_config, ts, 'workspaceName', 'workspace_name')
_move_field(aml_config, ts, 'computeTarget', 'compute_target')
_deprecate(aml_config, v2, 'maxTrialNumPerGpu')
_move_field(aml_config, ts, 'maxTrialNumPerGpu', 'max_trial_number_per_gpu')
_deprecate(aml_config, v2, 'useActiveGpu')
assert not aml_config, aml_config

Expand Down
1 change: 1 addition & 0 deletions ts/nni_manager/common/experimentConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ export interface AmlConfig extends TrainingServiceConfig {
workspaceName: string;
computeTarget: string;
dockerImage: string;
maxTrialNumberPerGpu: number;
}

/* Kubeflow */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@ import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log';
import { getExperimentRootDir } from '../../../common/utils';
import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
import { ExperimentConfig, AmlConfig, flattenConfig } from '../../../common/experimentConfig';
import { validateCodeDir } from '../../common/util';
import { AMLClient } from '../aml/amlClient';
import { AMLClusterConfig, AMLEnvironmentInformation, AMLTrialConfig } from '../aml/amlConfig';
import { AMLEnvironmentInformation } from '../aml/amlConfig';
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { EventEmitter } from "events";
import { AMLCommandChannel } from '../channels/amlCommandChannel';
import { SharedStorageService } from '../sharedStorage'

/**
 * Single type exposing the fields of both ExperimentConfig and AmlConfig.
 * Used as the type of the service's `config` field, which is assigned from
 * `flattenConfig(config, 'aml')`.
 */
interface FlattenAmlConfig extends ExperimentConfig, AmlConfig { }

/**
* Collects AML job info from the AML cluster and updates the AML job status locally
Expand All @@ -26,15 +27,16 @@ import { SharedStorageService } from '../sharedStorage'
export class AMLEnvironmentService extends EnvironmentService {

private readonly log: Logger = getLogger();
public amlClusterConfig: AMLClusterConfig | undefined;
public amlTrialConfig: AMLTrialConfig | undefined;
private experimentId: string;
private experimentRootDir: string;
private config: FlattenAmlConfig;

constructor() {
constructor(config: ExperimentConfig) {
super();
this.experimentId = getExperimentId();
this.experimentRootDir = getExperimentRootDir();
this.config = flattenConfig(config, 'aml');
validateCodeDir(this.config.trialCodeDirectory);
}

public get hasStorageService(): boolean {
Expand All @@ -53,27 +55,6 @@ export class AMLEnvironmentService extends EnvironmentService {
return 'aml';
}

/**
 * Applies one metadata key/value pair to this service.
 *
 * - AML_CLUSTER_CONFIG: parses `value` as JSON and stores it in `amlClusterConfig`.
 * - TRIAL_CONFIG: requires `amlClusterConfig` to be set already; otherwise an
 *   error is logged and the value is ignored. When set, parses `value` as JSON
 *   into `amlTrialConfig` and awaits `validateCodeDir` on its `codeDir`.
 * - Any other key: only logged at debug level.
 *
 * @param key metadata key, compared against TrialConfigMetadataKey members
 * @param value JSON-encoded payload for the two handled keys
 */
public async config(key: string, value: string): Promise<void> {
switch (key) {
case TrialConfigMetadataKey.AML_CLUSTER_CONFIG:
// The parsed JSON is cast without any shape validation.
this.amlClusterConfig = <AMLClusterConfig>JSON.parse(value);
break;

case TrialConfigMetadataKey.TRIAL_CONFIG: {
// The cluster config must arrive before the trial config; this logs rather than throws.
if (this.amlClusterConfig === undefined) {
this.log.error('aml cluster config is not initialized');
break;
}
this.amlTrialConfig = <AMLTrialConfig>JSON.parse(value);
// Validate to make sure codeDir doesn't have too many files
await validateCodeDir(this.amlTrialConfig.codeDir);
break;
}
default:
// Unhandled keys are ignored apart from this debug trace.
this.log.debug(`AML not proccessed metadata key: '${key}', value: '${value}'`);
}
}

public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
environments.forEach(async (environment) => {
const amlClient = (environment as AMLEnvironmentInformation).amlClient;
Expand Down Expand Up @@ -107,12 +88,6 @@ export class AMLEnvironmentService extends EnvironmentService {
}

public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
if (this.amlClusterConfig === undefined) {
throw new Error('AML Cluster config is not initialized');
}
if (this.amlTrialConfig === undefined) {
throw new Error('AML trial config is not initialized');
}
const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation;
const environmentLocalTempFolder = path.join(this.experimentRootDir, "environment-temp");
if (!fs.existsSync(environmentLocalTempFolder)) {
Expand All @@ -126,22 +101,24 @@ export class AMLEnvironmentService extends EnvironmentService {
amlEnvironment.command = `mv envs outputs/envs && cd outputs && ${amlEnvironment.command}`;
}
amlEnvironment.command = `import os\nos.system('${amlEnvironment.command}')`;
amlEnvironment.useActiveGpu = this.amlClusterConfig.useActiveGpu;
amlEnvironment.maxTrialNumberPerGpu = this.amlClusterConfig.maxTrialNumPerGpu;
amlEnvironment.useActiveGpu = !!this.config.deprecated.useActiveGpu;
amlEnvironment.maxTrialNumberPerGpu = this.config.maxTrialNumberPerGpu;

await fs.promises.writeFile(path.join(environmentLocalTempFolder, 'nni_script.py'), amlEnvironment.command, { encoding: 'utf8' });
const amlClient = new AMLClient(
this.amlClusterConfig.subscriptionId,
this.amlClusterConfig.resourceGroup,
this.amlClusterConfig.workspaceName,
this.config.subscriptionId,
this.config.resourceGroup,
this.config.workspaceName,
this.experimentId,
this.amlClusterConfig.computeTarget,
this.amlTrialConfig.image,
this.config.computeTarget,
this.config.dockerImage,
'nni_script.py',
environmentLocalTempFolder
);
amlEnvironment.id = await amlClient.submit();
this.log.debug('aml: before getTrackingUrl');
amlEnvironment.trackingUrl = await amlClient.getTrackingUrl();
this.log.debug('aml: after getTrackingUrl');
amlEnvironment.amlClient = amlClient;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export class EnvironmentServiceFactory {
case 'remote':
return new RemoteEnvironmentService(config);
case 'aml':
return new AMLEnvironmentService();
return new AMLEnvironmentService(config);
case 'openpai':
return new OpenPaiEnvironmentService(config);
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ class TrialDispatcher implements TrainingService {
const reuseMode = Array.isArray(this.config.trainingService) || (this.config.trainingService as any).reuseMode;
if (
0 === environment.runningTrialCount &&
!reuseMode &&
reuseMode === false &&
environment.assignedTrialCount > 0
) {
if (environment.environmentService === undefined) {
Expand Down

0 comments on commit dc54f4a

Please sign in to comment.