This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

update nnicli #2713

Merged (24 commits), Aug 12, 2020
Changes from 10 commits
1 change: 1 addition & 0 deletions docs/en_US/conf.py
@@ -17,6 +17,7 @@
import os
import sys
sys.path.insert(0, os.path.abspath('../../src/sdk/pynni'))
sys.path.insert(1, os.path.abspath('../../src/sdk/pycli'))


# -- Project information ---------------------------------------------------
41 changes: 41 additions & 0 deletions docs/en_US/nnicli_ref.md
@@ -0,0 +1,41 @@
# NNI Client

NNI client is a Python API for `nnictl` that implements its most commonly used commands. With this API, users can control their experiments, collect experiment results, and run advanced analyses on those results directly in Python code instead of on the command line. Here is an example:
Contributor

common used -> commonly used

Contributor Author

Fixed.


```
import nnicli as nc

nc.start_experiment('nni/examples/trials/mnist-pytorch/config.yml', port=9090) # start an experiment

nc.set_endpoint('http://localhost:9090') # set the experiment's endpoint, i.e., the url of Web UI
Contributor

What is this API used for? After an experiment is created, the endpoint is fixed, right? Why do we need to set it?

Contributor

I see, this endpoint needs to be set in order to query the RESTful APIs.

Contributor

Is it possible to detect this endpoint automatically? It is counterintuitive that the user is required to set the endpoint.

Contributor Author

Agreed. Maybe we can get the endpoint when the user calls start_experiment. But if users want to operate on other experiments, they still need to set the endpoint on their own.

print(nc.version()) # check the version of nni
print(nc.get_experiment_status()) # get the experiment's status

print(nc.get_job_statistics()) # get the trial job information
print(nc.list_trial_jobs()) # get information for all trial jobs

nc.stop_experiment(port=9090) # stop the experiment
```
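
The review thread on `set_endpoint` asks whether the endpoint could be derived automatically. A minimal sketch of one option, assuming the REST server runs on localhost at the port passed to `start_experiment`. `endpoint_for_port` is a hypothetical helper, not part of `nnicli`, and the fallback port of 8080 is an assumption taken from the example in the `nni_client.py` docstring.

```python
DEFAULT_PORT = 8080  # assumption: same port as the nni_client.py docstring example


def endpoint_for_port(port=None, host='localhost'):
    """Build the URL string that set_endpoint() expects (hypothetical helper)."""
    if port is None:
        port = DEFAULT_PORT
    port = int(port)
    if not 0 < port < 65536:
        raise ValueError('port must be between 1 and 65535, got {}'.format(port))
    return 'http://{}:{}'.format(host, port)


# The same port value could be passed to start_experiment(..., port=9090)
# and then to set_endpoint(endpoint_for_port(9090)).
print(endpoint_for_port(9090))
```

As the author notes in the thread, this would only cover experiments started from the same script; switching to another experiment would still need an explicit `set_endpoint` call.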

## References

```eval_rst
.. autofunction:: nnicli.start_experiment
.. autofunction:: nnicli.set_endpoint
.. autofunction:: nnicli.resume_experiment
.. autofunction:: nnicli.view_experiment
.. autofunction:: nnicli.update_searchspace
.. autofunction:: nnicli.update_concurrency
.. autofunction:: nnicli.update_duration
.. autofunction:: nnicli.update_trailnum
.. autofunction:: nnicli.stop_experiment
.. autofunction:: nnicli.version
.. autofunction:: nnicli.get_experiment_status
.. autofunction:: nnicli.get_experiment_profile
.. autofunction:: nnicli.get_trial_job
.. autofunction:: nnicli.list_trial_jobs
.. autofunction:: nnicli.get_job_statistics
.. autofunction:: nnicli.get_job_metrics
.. autofunction:: nnicli.export_data
```
3 changes: 2 additions & 1 deletion docs/en_US/sdk_reference.rst
@@ -8,4 +8,5 @@ Python API Reference

Auto Tune <autotune_ref>
NAS <NAS/NasReference>
Compression Utilities <Compressor/CompressionReference>
Compression Utilities <Compressor/CompressionReference>
NNI Client <nnicli_ref>
226 changes: 205 additions & 21 deletions src/sdk/pycli/nnicli/nni_client.py
@@ -7,7 +7,7 @@

import nnicli as nc

nc.start_nni('../../../../examples/trials/mnist/config.yml')
nc.start_experiment('../../../../examples/trials/mnist-pytorch/config.yml')

nc.set_endpoint('http://localhost:8080')

@@ -17,7 +17,7 @@
print(nc.get_job_statistics())
print(nc.list_trial_jobs())

nc.stop_nni()
nc.stop_experiment()

"""

@@ -27,9 +27,16 @@
import requests

__all__ = [
'start_nni',
'stop_nni',
'start_experiment',
'set_endpoint',
'stop_experiment',
'resume_experiment',
'view_experiment',
'update_searchspace',
'update_concurrency',
'update_duration',
'update_trailnum',
'stop_experiment',
'version',
'get_experiment_status',
'get_experiment_profile',
@@ -53,8 +60,14 @@
_api_endpoint = None

def set_endpoint(endpoint):
"""set endpoint of nni rest server for nnicli, for example:
http://localhost:8080
"""
Set the endpoint of the NNI REST server for nnicli, i.e., the URL of the Web UI.
Call this function first every time you want to switch to another experiment.

Parameters
----------
endpoint: str
the endpoint of nni rest server for nnicli
"""
global _api_endpoint
_api_endpoint = endpoint
@@ -92,48 +105,219 @@ def _create_process(cmd):
print(output.decode('utf-8').strip())
return process.returncode

def start_nni(config_file):
"""start nni experiment with specified configuration file"""
def start_experiment(config_file, port=None, debug=False):
"""
Start an experiment with specified configuration file.

Parameters
----------
config_file: str
path to the config file
port: int
the port of restful server, bigger than 1024
debug: boolean
set debug mode
"""
cmd = 'nnictl create --config {}'.format(config_file).split(' ')
if port:
cmd += '--port {}'.format(port).split(' ')
if debug:
cmd += ['--debug']
if _create_process(cmd) != 0:
raise RuntimeError('Failed to start experiment.')
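
Note on the command construction in `start_experiment`: formatting the command as one string and then calling `split(' ')` breaks a `config_file` path that contains spaces into several arguments. A minimal sketch of an alternative that builds the argument list directly, assuming the list is passed unchanged to the subprocess call. `build_create_cmd` is a hypothetical helper and is not part of this PR; only the `--config`, `--port` and `--debug` flags shown in the diff are used.

```python
def build_create_cmd(config_file, port=None, debug=False):
    """Return the argument list for `nnictl create` (hypothetical helper)."""
    # Each argument is its own list element, so spaces in the path survive.
    cmd = ['nnictl', 'create', '--config', str(config_file)]
    if port is not None:
        cmd += ['--port', str(port)]
    if debug:
        cmd.append('--debug')
    return cmd


print(build_create_cmd('my experiments/config.yml', port=9090, debug=True))
```

The same pattern would apply to `resume_experiment` and `view_experiment`, which build their commands the same way.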

def resume_experiment(exp_id, port=None, debug=False):
"""
Resume a stopped experiment with specified experiment id

Parameters
----------
exp_id: str
experiment id
port: int
the port of restful server, bigger than 1024
debug: boolean
set debug mode
"""
cmd = 'nnictl resume {}'.format(exp_id).split(' ')
if port:
cmd += '--port {}'.format(port).split(' ')
if debug:
cmd += ['--debug']
if _create_process(cmd) != 0:
raise RuntimeError('Failed to resume experiment.')

def view_experiment(exp_id, port=None):
"""
View a stopped experiment with specified experiment id

Parameters
----------
exp_id: str
experiment id
port: int
the port of restful server, bigger than 1024
"""
cmd = 'nnictl view {}'.format(exp_id).split(' ')
if port:
cmd += '--port {}'.format(port).split(' ')
if _create_process(cmd) != 0:
raise RuntimeError('Failed to start nni.')
raise RuntimeError('Failed to view experiment.')

def stop_nni():
"""stop nni experiment"""
cmd = 'nnictl stop'.split(' ')
def update_searchspace(filename, exp_id=None):
"""
Update an experiment's search space

Parameters
----------
filename: str
path to the searchspace file
exp_id: str
experiment id
"""
if not exp_id:
cmd = 'nnictl update searchspace --filename {}'.format(filename).split(' ')
else:
cmd = 'nnictl update searchspace {} --filename {}'.format(exp_id, filename).split(' ')
if _create_process(cmd) != 0:
raise RuntimeError('Failed to update searchspace.')

def update_concurrency(value, exp_id=None):
"""
Update an experiment's concurrency

Parameters
----------
value: int
new concurrency value
exp_id: str
experiment id
"""
if not exp_id:
cmd = 'nnictl update concurrency --value {}'.format(value).split(' ')
else:
cmd = 'nnictl update concurrency {} --value {}'.format(exp_id, value).split(' ')
if _create_process(cmd) != 0:
raise RuntimeError('Failed to update concurrency.')

def update_duration(value, exp_id=None):
"""
Update an experiment's duration

Parameters
----------
value: str
new duration: a number followed by an optional unit suffix, which may be 's' for seconds (the default), 'm' for minutes, 'h' for hours or 'd' for days, e.g., '1m', '2h'
exp_id: str
experiment id
"""
if not exp_id:
cmd = 'nnictl update duration --value {}'.format(value).split(' ')
else:
cmd = 'nnictl update duration {} --value {}'.format(exp_id, value).split(' ')
if _create_process(cmd) != 0:
raise RuntimeError('Failed to stop nni.')
raise RuntimeError('Failed to update duration.')
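
The suffix rule in the `update_duration` docstring can be illustrated with a small validator. This is a sketch that follows only the wording of the docstring ('s', 'm', 'h', 'd', with seconds as the default); it is not how `nnictl` itself parses the value, and `duration_to_seconds` is a hypothetical name.

```python
import re

# A whole number followed by at most one of the documented unit suffixes.
_DURATION_RE = re.compile(r'(\d+)([smhd]?)')
_SECONDS_PER_UNIT = {'': 1, 's': 1, 'm': 60, 'h': 3600, 'd': 86400}


def duration_to_seconds(value):
    """Convert a duration such as '1m' or '2h' to seconds (hypothetical helper)."""
    match = _DURATION_RE.fullmatch(str(value).strip())
    if match is None:
        raise ValueError('invalid duration: {!r}'.format(value))
    number, suffix = match.groups()
    return int(number) * _SECONDS_PER_UNIT[suffix]


print(duration_to_seconds('1m'), duration_to_seconds('2h'))
```

Checking the value before building the `nnictl update duration` command would turn a malformed string into a `ValueError` at the call site rather than a generic `RuntimeError` from the subprocess.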

def update_trailnum(value, exp_id=None):
"""
Update an experiment's maximum trial number (maxTrialNum)

Parameters
----------
value: int
new trialnum value
exp_id: str
experiment id
"""
if not exp_id:
cmd = 'nnictl update trialnum --value {}'.format(value).split(' ')
else:
cmd = 'nnictl update trialnum {} --value {}'.format(exp_id, value).split(' ')
if _create_process(cmd) != 0:
raise RuntimeError('Failed to update trailnum.')

def stop_experiment(exp_id=None, port=None, stop_all=False):
"""Stop an experiment.

Parameters
----------
exp_id: str
experiment id
port: int
the port of restful server
stop_all: boolean
if set to True, all the experiments will be stopped

Note that if stop_all is set to true, exp_id and port will be ignored. Otherwise
exp_id and port must correspond to the same experiment if they are both set.
"""
if stop_all:
cmd = 'nnictl stop --all'.split(' ')
else:
cmd = 'nnictl stop'.split(' ')
if exp_id:
cmd += [exp_id]
if port:
cmd += '--port {}'.format(port).split(' ')
if _create_process(cmd) != 0:
raise RuntimeError('Failed to stop experiment.')
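
The precedence described in the `stop_experiment` docstring (`stop_all` overrides `exp_id` and `port`) can be checked in isolation by separating command construction from execution. A minimal sketch, assuming only the `nnictl stop` arguments visible in the diff; `build_stop_cmd` and the experiment id `'GyXpK2'` are hypothetical.

```python
def build_stop_cmd(exp_id=None, port=None, stop_all=False):
    """Return the argument list for `nnictl stop` (hypothetical helper)."""
    if stop_all:
        # exp_id and port are ignored, as the docstring states.
        return ['nnictl', 'stop', '--all']
    cmd = ['nnictl', 'stop']
    if exp_id:
        cmd.append(str(exp_id))
    if port is not None:
        cmd += ['--port', str(port)]
    return cmd


print(build_stop_cmd())
print(build_stop_cmd(exp_id='GyXpK2', port=9090))
print(build_stop_cmd(exp_id='GyXpK2', stop_all=True))
```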

def version():
"""return version of nni"""
"""
Return version of nni.
"""
return _nni_rest_get(VERSION_PATH, 'text')

def get_experiment_status():
"""return experiment status as a dict"""
"""
Return experiment status as a dict.
"""
return _nni_rest_get(STATUS_PATH)

def get_experiment_profile():
"""return experiment profile as a dict"""
"""
Return experiment profile as a dict.
"""
return _nni_rest_get(EXPERIMENT_PATH)

def get_trial_job(trial_job_id):
"""return trial job information as a dict"""
"""
Return trial job information as a dict.

Parameters
----------
trial_job_id: str
trial id
"""
assert trial_job_id is not None
return _nni_rest_get(os.path.join(TRIAL_JOBS_PATH, trial_job_id))

def list_trial_jobs():
"""return information for all trial jobs as a list"""
"""
Return information for all trial jobs as a list.
"""
return _nni_rest_get(TRIAL_JOBS_PATH)

def get_job_statistics():
"""return trial job statistics information as a dict"""
"""
Return trial job statistics information as a dict.
"""
return _nni_rest_get(JOB_STATISTICS_PATH)

def get_job_metrics(trial_job_id=None):
"""return trial job metrics"""
"""
Return trial job metrics.

Parameters
----------
trial_job_id: str
trial id. If this parameter is None, the metrics of all trial jobs are returned.
"""
api_path = METRICS_PATH if trial_job_id is None else os.path.join(METRICS_PATH, trial_job_id)
return _nni_rest_get(api_path)
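
Note on the path handling in `get_trial_job` and `get_job_metrics`: `os.path.join` uses the separator of the host operating system, which is a backslash on Windows, while URL paths always use forward slashes. A minimal sketch of a platform-independent join using `posixpath`; `join_api_path` is a hypothetical helper, and the `'metric-data'` value is only an assumed stand-in for `METRICS_PATH`, whose definition is not shown in this diff.

```python
import posixpath


def join_api_path(base_path, trial_job_id=None):
    """Join a REST path and an optional trial id with '/' (hypothetical helper)."""
    if trial_job_id is None:
        return base_path
    # posixpath.join always uses '/', regardless of the host OS.
    return posixpath.join(base_path, str(trial_job_id))


METRICS_PATH = 'metric-data'  # assumption: placeholder value for illustration
print(join_api_path(METRICS_PATH, 'abc'))
```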

def export_data():
"""return exported information for all trial jobs"""
"""
Return exported information for all trial jobs.
"""
return _nni_rest_get(EXPORT_DATA_PATH)
4 changes: 2 additions & 2 deletions test/config/integration_tests.yml
@@ -140,8 +140,8 @@ testCases:
config:
maxTrialNum: 4
trialConcurrency: 4
launchCommand: python3 -c 'import nnicli as nc; nc.start_nni("$configFile")'
stopCommand: python3 -c 'import nnicli as nc; nc.stop_nni()'
launchCommand: python3 -c 'import nnicli as nc; nc.start_experiment("$configFile")'
stopCommand: python3 -c 'import nnicli as nc; nc.stop_experiment()'
validator:
class: NnicliValidator
platform: linux darwin
4 changes: 2 additions & 2 deletions test/config/integration_tests_tf2.yml
@@ -110,8 +110,8 @@ testCases:
config:
maxTrialNum: 4
trialConcurrency: 4
launchCommand: python3 -c 'import nnicli as nc; nc.start_nni("$configFile")'
stopCommand: python3 -c 'import nnicli as nc; nc.stop_nni()'
launchCommand: python3 -c 'import nnicli as nc; nc.start_experiment("$configFile")'
stopCommand: python3 -c 'import nnicli as nc; nc.stop_experiment()'
validator:
class: NnicliValidator
platform: linux darwin
4 changes: 2 additions & 2 deletions test/config/pr_tests.yml
@@ -47,8 +47,8 @@ testCases:
config:
maxTrialNum: 2
trialConcurrency: 2
launchCommand: python3 -c 'import nnicli as nc; nc.start_nni("$configFile")'
stopCommand: python3 -c 'import nnicli as nc; nc.stop_nni()'
launchCommand: python3 -c 'import nnicli as nc; nc.start_experiment("$configFile")'
stopCommand: python3 -c 'import nnicli as nc; nc.stop_experiment()'
validator:
class: NnicliValidator
platform: linux darwin
2 changes: 1 addition & 1 deletion tools/nni_cmd/updater.py
@@ -14,7 +14,7 @@
def validate_digit(value, start, end):
'''validate if a digit is valid'''
if not str(value).isdigit() or int(value) < start or int(value) > end:
raise ValueError('%s must be a digit from %s to %s' % (value, start, end))
raise ValueError('%s must be a digit from %s to %s' % ('value', start, end))
Contributor

It is strange that you added quotes around value.

Contributor Author

I think there used to be a typo here. If the value is 1, it's strange to raise an error like ValueError('1 must be a digit from 1 to 1000'). Instead, it should be ValueError('value must be a digit from 1 to 1000').

Contributor (@QuanluZhang, Aug 10, 2020)

Then it should be raise ValueError('value must be a digit from %s to %s' % (start, end)). I still suggest the following:
raise ValueError('value (%s) must be a digit from %s to %s' % (value, start, end))

Contributor Author

Good idea. Will fix.
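
For reference, the message suggested in this thread can be tried in isolation. The sketch below copies `validate_digit` from the diff and swaps in the reviewer's proposed format string; it shows the suggestion only, not necessarily the code that was finally merged.

```python
def validate_digit(value, start, end):
    '''validate if a digit is valid'''
    if not str(value).isdigit() or int(value) < start or int(value) > end:
        # Reviewer's suggestion: name the parameter and show the rejected input.
        raise ValueError('value (%s) must be a digit from %s to %s' % (value, start, end))


try:
    validate_digit('0', 1, 1000)
except ValueError as err:
    print(err)  # value (0) must be a digit from 1 to 1000
```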


def validate_file(path):
'''validate if a file exist'''