Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

fix experiment import bug and add it cases: experiment import #2878

Merged
merged 4 commits into from
Sep 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/en_US/TrialExample/SklearnExamples.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ It is easy to use NNI in your scikit-learn code, there are only a few steps.
"kernel": {"_type":"choice","_value":["linear", "rbf", "poly", "sigmoid"]},
"degree": {"_type":"choice","_value":[1, 2, 3, 4]},
"gamma": {"_type":"uniform","_value":[0.01, 0.1]},
"coef0 ": {"_type":"uniform","_value":[0.01, 0.1]}
"coef0": {"_type":"uniform","_value":[0.01, 0.1]}
}
```

Expand Down
2 changes: 1 addition & 1 deletion examples/trials/sklearn/classification/search_space.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
"kernel": {"_type":"choice","_value":["linear", "rbf", "poly", "sigmoid"]},
"degree": {"_type":"choice","_value":[1, 2, 3, 4]},
"gamma": {"_type":"uniform","_value":[0.01, 0.1]},
"coef0 ": {"_type":"uniform","_value":[0.01, 0.1]}
"coef0": {"_type":"uniform","_value":[0.01, 0.1]}
}
1 change: 1 addition & 0 deletions src/nni_manager/common/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ abstract class Manager {
public abstract getExperimentProfile(): Promise<ExperimentProfile>;
public abstract updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void>;
public abstract importData(data: string): Promise<void>;
public abstract getImportedData(): Promise<string[]>;
public abstract exportData(): Promise<string>;

public abstract addCustomizedTrialJob(hyperParams: string): Promise<number>;
Expand Down
4 changes: 4 additions & 0 deletions src/nni_manager/core/nnimanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ class NNIManager implements Manager {
return this.dataStore.storeTrialJobEvent('IMPORT_DATA', '', data);
}

public getImportedData(): Promise<string[]> {
return this.dataStore.getImportedData();
}

public async exportData(): Promise<string> {
return this.dataStore.exportTrialHpConfigs();
}
Expand Down
11 changes: 11 additions & 0 deletions src/nni_manager/rest_server/restHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class NNIRestHandler {
this.getExperimentProfile(router);
this.updateExperimentProfile(router);
this.importData(router);
this.getImportedData(router);
this.startExperiment(router);
this.getTrialJobStatistics(router);
this.setClusterMetaData(router);
Expand Down Expand Up @@ -143,6 +144,16 @@ class NNIRestHandler {
});
}

private getImportedData(router: Router): void {
router.get('/experiment/imported-data', (req: Request, res: Response) => {
this.nniManager.getImportedData().then((importedData: string[]) => {
res.send(JSON.stringify(importedData));
}).catch((err: Error) => {
this.handleError(err, res);
});
});
}

private startExperiment(router: Router): void {
router.post('/experiment', expressJoi(ValidationSchemas.STARTEXPERIMENT), (req: Request, res: Response) => {
if (isNewExperiment()) {
Expand Down
4 changes: 4 additions & 0 deletions src/nni_manager/rest_server/test/mockedNNIManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ export class MockedNNIManager extends Manager {
public importData(data: string): Promise<void> {
return Promise.resolve();
}
public getImportedData(): Promise<string[]> {
const ret: string[] = ["1", "2"];
return Promise.resolve(ret);
}
public async exportData(): Promise<string> {
const ret: string = '';
return Promise.resolve(ret);
Expand Down
1 change: 1 addition & 0 deletions src/sdk/pynni/nni/msg_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def handle_import_data(self, data):
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
for entry in data:
entry['value'] = entry['value'] if type(entry['value']) is str else json_tricks.dumps(entry['value'])
entry['value'] = json_tricks.loads(entry['value'])
self.tuner.import_data(data)

Expand Down
7 changes: 7 additions & 0 deletions test/config/integration_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,13 @@ testCases:
validator:
class: ExportValidator

- name: experiment-import
configFile: test/config/nnictl_experiment/sklearn-classification.yml
validator:
class: ImportValidator
kwargs:
import_data_file_path: config/nnictl_experiment/test_import.json

- name: nnicli
configFile: test/config/examples/sklearn-regression.yml
config:
Expand Down
23 changes: 23 additions & 0 deletions test/config/nnictl_experiment/sklearn-classification.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 4
trialConcurrency: 2
searchSpacePath: ../../../examples/trials/sklearn/classification/search_space.json

tuner:
builtinTunerName: TPE
assessor:
builtinAssessorName: Medianstop
classArgs:
optimize_mode: maximize
trial:
codeDir: ../../../examples/trials/sklearn/classification
command: python3 main.py
gpuNum: 0

useAnnotation: false
multiPhase: false
multiThread: false

trainingServicePlatform: local
4 changes: 4 additions & 0 deletions test/config/nnictl_experiment/test_import.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[
{"parameter": {"C": 0.15940134774738896, "kernel": "sigmoid", "degree": 3, "gamma": 0.07295826917955316, "coef0": 0.0978204758732429}, "value": 0.6},
{"parameter": {"C": 0.5556430724708544, "kernel": "linear", "degree": 3, "gamma": 0.04957496655414671, "coef0": 0.08520868779907687}, "value": 0.7}
]
1 change: 1 addition & 0 deletions test/nni_test/nnitest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
STATUS_URL = API_ROOT_URL + '/check-status'
TRIAL_JOBS_URL = API_ROOT_URL + '/trial-jobs'
METRICS_URL = API_ROOT_URL + '/metric-data'
GET_IMPORTED_DATA_URL = API_ROOT_URL + '/experiment/imported-data'

def read_last_line(file_name):
'''read last line of a file and return None if file not found'''
Expand Down
14 changes: 13 additions & 1 deletion test/nni_test/nnitest/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import json
import requests
from nnicli import Experiment
from utils import METRICS_URL
from nni_cmd.updater import load_search_space
from utils import METRICS_URL, GET_IMPORTED_DATA_URL


class ITValidator:
Expand All @@ -33,6 +34,17 @@ def __call__(self, rest_endpoint, experiment_dir, nni_source_dir, **kwargs):
print('\n\n')
remove('report.json')

class ImportValidator(ITValidator):
def __call__(self, rest_endpoint, experiment_dir, nni_source_dir, **kwargs):
exp_id = osp.split(experiment_dir)[-1]
import_data_file_path = kwargs.get('import_data_file_path')
proc = subprocess.run(['nnictl', 'experiment', 'import', exp_id, '-f', import_data_file_path])
assert proc.returncode == 0, \
'`nnictl experiment import {0} -f {1}` failed with code {2}'.format(exp_id, import_data_file_path, proc.returncode)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any other means to verify the data is actually imported successfully besides checking the return code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll view the code again to find if there is a way to check the result.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to check the imported data, add an api api/v1/nni/experiment/imported-data to get the imported data. I don't know if it is appropriate to do so. Maybe users also have demand to view the data they have imported?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is OK to call api/v1/nni/experiment/imported-data to verify

imported_data = requests.get(GET_IMPORTED_DATA_URL).json()
origin_data = load_search_space(import_data_file_path).replace(' ', '')
assert origin_data in imported_data

class MetricsValidator(ITValidator):
def __call__(self, rest_endpoint, experiment_dir, nni_source_dir, **kwargs):
self.check_metrics(nni_source_dir, **kwargs)
Expand Down
16 changes: 14 additions & 2 deletions tools/nni_cmd/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .url_utils import experiment_url, import_data_url
from .config_utils import Config
from .common_utils import get_json_content, print_normal, print_error, print_warning
from .nnictl_utils import get_experiment_port, get_config_filename
from .nnictl_utils import get_experiment_port, get_config_filename, detect_process
from .launcher_utils import parse_time
from .constants import REST_TIME_OUT, TUNERS_SUPPORTING_IMPORT_DATA, TUNERS_NO_NEED_TO_IMPORT_DATA

Expand Down Expand Up @@ -115,7 +115,19 @@ def import_data(args):
validate_file(args.filename)
validate_dispatcher(args)
content = load_search_space(args.filename)
args.port = get_experiment_port(args)

nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
running, _ = check_rest_server_quick(rest_port)
if not running:
print_error('Restful server is not running')
return

args.port = rest_port
if args.port is not None:
if import_data_to_restful_server(args, content):
pass
Expand Down