Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding HPO unit test #3791

Merged
merged 3 commits into from
May 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions components/aws/sagemaker/common/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def create_hyperparameter_tuning_job(client, args):
"""Create a Sagemaker HPO job"""
request = create_hyperparameter_tuning_job_request(args)
try:
job_arn = client.create_hyper_parameter_tuning_job(**request)
client.create_hyper_parameter_tuning_job(**request)
hpo_job_name = request['HyperParameterTuningJobName']
logging.info("Created Hyperparameter Training Job with name: " + hpo_job_name)
logging.info("HPO job in SageMaker: https://{}.console.aws.amazon.com/sagemaker/home?region={}#/hyper-tuning-jobs/{}"
Expand All @@ -604,7 +604,7 @@ def create_hyperparameter_tuning_job(client, args):
raise Exception(e.response['Error']['Message'])


def wait_for_hyperparameter_training_job(client, hpo_job_name):
def wait_for_hyperparameter_training_job(client, hpo_job_name, poll_interval=30):
### Wait until the job finishes
while(True):
response = client.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=hpo_job_name)
Expand All @@ -617,7 +617,7 @@ def wait_for_hyperparameter_training_job(client, hpo_job_name):
logging.error('Hyperparameter tuning failed with the following error: {}'.format(message))
raise Exception('Hyperparameter tuning job failed')
logging.info("Hyperparameter tuning job is still in status: " + status)
time.sleep(30)
time.sleep(poll_interval)


def get_best_training_job_and_hyperparameters(client, hpo_job_name):
Expand Down Expand Up @@ -880,4 +880,4 @@ def yaml_or_json_str(str):
def str_to_bool(str):
# distutils' strtobool returns an integer (1 or 0) rather than a
# True/False value, so cast the result to a real bool.
return bool(strtobool(str))
return bool(strtobool(str))
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def create_parser():

parser.add_argument('--job_name', type=str, required=False, help='The name of the tuning job. Must be unique within the same AWS account and AWS region.')
parser.add_argument('--role', type=str, required=True, help='The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf.')
parser.add_argument('--image', type=str, required=True, help='The registry path of the Docker image that contains the training algorithm.', default='')
parser.add_argument('--image', type=str, required=False, help='The registry path of the Docker image that contains the training algorithm.', default='')
parser.add_argument('--algorithm_name', type=str, required=False, help='The name of the resource algorithm to use for the hyperparameter tuning job.', default='')
parser.add_argument('--training_input_mode', choices=['File', 'Pipe'], type=str, required=False, help='The input mode that the algorithm supports. File or Pipe.', default='File')
parser.add_argument('--metric_definitions', type=_utils.yaml_or_json_str, required=False, help='The dictionary of name-regex pairs specify the metrics that the algorithm emits.', default={})
Expand Down Expand Up @@ -65,7 +65,7 @@ def create_parser():

def main(argv=None):
parser = create_parser()
args = parser.parse_args()
args = parser.parse_args(argv)

logging.getLogger().setLevel(logging.INFO)
client = _utils.get_sagemaker_client(args.region)
Expand All @@ -92,4 +92,4 @@ def main(argv=None):


if __name__== "__main__":
main()
main(sys.argv[1:])
Loading