diff --git a/docs/en_US/Tutorial/Nnictl.md b/docs/en_US/Tutorial/Nnictl.md index 28084494a3..edb25afacd 100644 --- a/docs/en_US/Tutorial/Nnictl.md +++ b/docs/en_US/Tutorial/Nnictl.md @@ -578,6 +578,7 @@ Debug mode will disable version check function in Trialkeeper. |--path, -p| True| |the file path of nni package| |--codeDir, -c| True| |the path of codeDir for loaded experiment, this path will also put the code in the loaded experiment package| |--logDir, -l| False| |the path of logDir for loaded experiment| + |--searchSpacePath, -s| True| |the path of search space file for loaded experiment, this path contains file name. Default in $codeDir/search_space.json| * Examples diff --git a/tools/nni_cmd/nnictl.py b/tools/nni_cmd/nnictl.py index 213554d5e8..cc5d4cdd11 100644 --- a/tools/nni_cmd/nnictl.py +++ b/tools/nni_cmd/nnictl.py @@ -159,6 +159,8 @@ def parse_args(): parser_load_experiment.add_argument('--codeDir', '-c', required=True, help='the path of codeDir for loaded experiment, \ this path will also put the code in the loaded experiment package') parser_load_experiment.add_argument('--logDir', '-l', required=False, help='the path of logDir for loaded experiment') + parser_load_experiment.add_argument('--searchSpacePath', '-s', required=False, help='the path of search space file for \ + loaded experiment, this path contains file name. Default in $codeDir/search_space.json') parser_load_experiment.set_defaults(func=load_experiment) #parse platform command diff --git a/tools/nni_cmd/nnictl_utils.py b/tools/nni_cmd/nnictl_utils.py index 3833c9317f..e59c293e2a 100644 --- a/tools/nni_cmd/nnictl_utils.py +++ b/tools/nni_cmd/nnictl_utils.py @@ -827,7 +827,18 @@ def save_experiment(args): temp_code_dir = os.path.join(temp_root_dir, 'code') shutil.copytree(nni_config.get_config('experimentConfig')['trial']['codeDir'], temp_code_dir) - # Step4. Archive folder + # Step4. Copy searchSpace file + search_space_path = nni_config.get_config('experimentConfig').get('searchSpacePath') + if search_space_path: + if not os.path.exists(search_space_path): + print_warning('search space %s does not exist!' % search_space_path) + else: + temp_search_space_dir = os.path.join(temp_root_dir, 'searchSpace') + os.makedirs(temp_search_space_dir, exist_ok=True) + search_space_name = os.path.basename(search_space_path) + shutil.copyfile(search_space_path, os.path.join(temp_search_space_dir, search_space_name)) + + # Step5. Archive folder zip_package_name = 'nni_experiment_%s' % args.id if args.path: os.makedirs(args.path, exist_ok=True) @@ -844,6 +855,9 @@ def load_experiment(args): if not os.path.exists(args.path): print_error('file path %s does not exist!' % args.path) exit(1) + if args.searchSpacePath and os.path.isdir(args.searchSpacePath): + print_error('search space path should be a full path with filename, not a directory!') + exit(1) temp_root_dir = generate_temp_dir() shutil.unpack_archive(package_path, temp_root_dir) print_normal('Loading...') @@ -929,7 +943,32 @@ def load_experiment(args): else: shutil.copy(src_path, target_path) - # Step5. Create experiment metadata + # Step5. Copy searchSpace file + archive_search_space_dir = os.path.join(temp_root_dir, 'searchSpace') + if args.searchSpacePath: + target_path = os.path.expanduser(args.searchSpacePath) + else: + # set default path to codeDir + target_path = os.path.join(codeDir, 'search_space.json') + if not os.path.isabs(target_path): + target_path = os.path.join(os.getcwd(), target_path) + print_normal('Expand search space path to %s' % target_path) + nnictl_exp_config['searchSpacePath'] = target_path + # if the path already has a search space file, use the original one, otherwise use archived one + if not os.path.isfile(target_path): + if len(os.listdir(archive_search_space_dir)) == 0: + print_error('Archive file does not contain search space file!') + exit(1) + else: + for file in os.listdir(archive_search_space_dir): + source_path = os.path.join(archive_search_space_dir, file) + os.makedirs(os.path.dirname(target_path), exist_ok=True) + shutil.copyfile(source_path, target_path) + break + elif not args.searchSpacePath: + print_warning('%s exist, will not load search_space file' % target_path) + + # Step6. Create experiment metadata nni_config.set_config('experimentConfig', nnictl_exp_config) experiment_config.add_experiment(experiment_id, experiment_metadata.get('port'),