-
Notifications
You must be signed in to change notification settings - Fork 1.8k
feature: export experiment results #2706
Changes from 12 commits
1362b00
a128bc0
e862c1a
d3e03a7
7ac572b
8379c0f
194679e
3b23df0
ff3eea9
8d9713f
dea1c23
4316393
7c7dba8
5ef17ff
74827a5
6ff03d0
a58a588
ca195da
135b7b2
7d115a5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,14 +9,15 @@ | |
import re | ||
import shutil | ||
import subprocess | ||
from functools import reduce | ||
from datetime import datetime, timezone | ||
from pathlib import Path | ||
from subprocess import Popen | ||
from pyhdfs import HdfsClient | ||
from nni.package_utils import get_nni_installation_path | ||
from nni_annotation import expand_annotations | ||
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response | ||
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url | ||
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url, metric_data_url | ||
from .config_utils import Config, Experiments | ||
from .constants import NNICTL_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \ | ||
EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT | ||
|
@@ -681,45 +682,67 @@ def monitor_experiment(args): | |
set_monitor(False, args.time) | ||
|
||
def export_trials_data(args): | ||
'''export experiment metadata to csv | ||
'''export experiment metadata and intermediate results to json or csv | ||
''' | ||
def groupby_trial_id(intermediate_results):
    '''Group intermediate metric payloads by trial id, in timestamp order.

    intermediate_results: iterable of dicts, each providing the keys
        'trialJobId', 'timestamp' and 'data' (a string).

    Returns a plain dict mapping every trial id to the list of its 'data'
    strings, ordered by ascending 'timestamp'. The input is not modified.
    '''
    grouped = dict()
    # sorted() returns a new list and leaves its argument untouched, so the
    # ordered sequence has to be the thing that is iterated; calling it and
    # discarding the result leaves the records in arrival order.
    for content in sorted(intermediate_results, key=lambda x: x['timestamp']):
        # [2:-2] drops the first and last two characters of the payload --
        # presumably the quoting added when the metric was serialised.
        # NOTE(review): confirm against the REST server's metric format.
        grouped.setdefault(content['trialJobId'], []).append(content['data'][2:-2])
    return grouped
|
||
def trans_intermediate_dict(record):
    '''Render one trial's intermediate results as a single csv cell.

    record: list of intermediate values for one trial (strings in the
        visible caller; any value accepted by str() also works).

    Returns a one-entry dict {'intermediate': '[v1,v2,...]'}; an empty
    list yields '[]'.
    '''
    # str.join handles the empty list and, combined with str(), non-string
    # values; a reduce() without an initial value raises TypeError on [].
    return {'intermediate': '[' + ','.join(map(str, record)) + ']'}
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't know the content of `record`. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What is the content of `record`? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have not yet tested intermediate results with a dict metric. I will do it this afternoon. |
||
|
||
nni_config = Config(get_config_filename(args)) | ||
rest_port = nni_config.get_config('restServerPort') | ||
rest_pid = nni_config.get_config('restServerPid') | ||
|
||
if not detect_process(rest_pid): | ||
print_error('Experiment is not running...') | ||
return | ||
running, response = check_rest_server_quick(rest_port) | ||
if running: | ||
response = rest_get(export_data_url(rest_port), 20) | ||
if response is not None and check_response(response): | ||
if args.type == 'json': | ||
with open(args.path, 'w') as file: | ||
file.write(response.text) | ||
elif args.type == 'csv': | ||
content = json.loads(response.text) | ||
trial_records = [] | ||
for record in content: | ||
record_value = json.loads(record['value']) | ||
if not isinstance(record_value, (float, int)): | ||
formated_record = {**record['parameter'], **record_value, **{'id': record['id']}} | ||
else: | ||
formated_record = {**record['parameter'], **{'reward': record_value, 'id': record['id']}} | ||
trial_records.append(formated_record) | ||
if not trial_records: | ||
print_error('No trial results collected! Please check your trial log...') | ||
exit(0) | ||
with open(args.path, 'w', newline='') as file: | ||
writer = csv.DictWriter(file, set.union(*[set(r.keys()) for r in trial_records])) | ||
writer.writeheader() | ||
writer.writerows(trial_records) | ||
else: | ||
print_error('Unknown type: %s' % args.type) | ||
exit(1) | ||
if not running: | ||
print_error('Restful server is not Running') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Running -> running There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
return | ||
response = rest_get(export_data_url(rest_port), 20) | ||
if response is not None and check_response(response): | ||
content = json.loads(response.text) | ||
if args.intermediate: | ||
intermediate_results = rest_get(metric_data_url(rest_port), REST_TIME_OUT) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in my point, use |
||
if not intermediate_results or not check_response(intermediate_results): | ||
print_error('Error getting intermediate results.') | ||
return | ||
intermediate_results = groupby_trial_id(json.loads(intermediate_results.text)) | ||
for record in content: | ||
record['intermediate'] = intermediate_results[record['id']] | ||
if args.type == 'json': | ||
with open(args.path, 'w') as file: | ||
file.write(json.dumps(content)) | ||
elif args.type == 'csv': | ||
trial_records = [] | ||
for record in content: | ||
formated_record = dict() | ||
if args.intermediate: | ||
formated_record.update({**trans_intermediate_dict(record['intermediate'])}) | ||
tabVersion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
record_value = json.loads(record['value']) | ||
if not isinstance(record_value, (float, int)): | ||
formated_record.update({**record['parameter'], **record_value, **{'id': record['id']}}) | ||
else: | ||
formated_record.update({**record['parameter'], **{'reward': record_value, 'id': record['id']}}) | ||
trial_records.append(formated_record) | ||
if not trial_records: | ||
print_error('No trial results collected! Please check your trial log...') | ||
exit(0) | ||
with open(args.path, 'w', newline='') as file: | ||
writer = csv.DictWriter(file, set.union(*[set(r.keys()) for r in trial_records])) | ||
writer.writeheader() | ||
writer.writerows(trial_records) | ||
else: | ||
print_error('Export failed...') | ||
print_error('Unknown type: %s' % args.type) | ||
return | ||
else: | ||
print_error('Restful server is not Running') | ||
print_error('Export failed...') | ||
|
||
def search_space_auto_gen(args): | ||
'''dry run trial code to generate search space file''' | ||
|
@@ -736,3 +759,4 @@ def search_space_auto_gen(args): | |
print_warning('Expected search space file \'{}\' generated, but not found.'.format(file_path)) | ||
else: | ||
print_normal('Generate search space done: \'{}\'.'.format(file_path)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are intermediate results included
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed