From 5ccbbb9a0ca7c23aa8103568ec9f12ddc7c2bf54 Mon Sep 17 00:00:00 2001 From: hongye-sun <43763191+hongye-sun@users.noreply.github.com> Date: Thu, 21 Mar 2019 14:11:38 -0700 Subject: [PATCH] Fix hard-coded model export dir name (#1015) * FIx hard-coded model export dir name * Prefix is not the full gcs path. --- .../kfp_component/google/ml_engine/_deploy.py | 60 ++++++++----------- .../tests/google/ml_engine/test__deploy.py | 7 +-- 2 files changed, 27 insertions(+), 40 deletions(-) diff --git a/component_sdk/python/kfp_component/google/ml_engine/_deploy.py b/component_sdk/python/kfp_component/google/ml_engine/_deploy.py index a1febfd3cb3..be9be3cede6 100644 --- a/component_sdk/python/kfp_component/google/ml_engine/_deploy.py +++ b/component_sdk/python/kfp_component/google/ml_engine/_deploy.py @@ -27,14 +27,15 @@ @decorators.SetParseFns(python_version=str, runtime_version=str) def deploy(model_uri, project_id, model_id=None, version_id=None, - runtime_version=None, python_version=None, version=None, + runtime_version=None, python_version=None, model=None, version=None, replace_existing_version=False, set_default=False, wait_interval=30): """Deploy a model to MLEngine from GCS URI Args: - model_uri (str): required, the GCS URI which contains a model file. - Common used TF model search path (export/exporter) will be used - if exist. + model_uri (str): Required, the GCS URI which contains a model file. + If no model file is found, the same path will be treated as an export + base directory of a TF Estimator. The last time-stamped sub-directory + will be chosen as model URI. project_id (str): required, the ID of the parent project. model_id (str): optional, the user provided name of the model. version_id (str): optional, the user provided name of the version. @@ -46,7 +47,10 @@ def deploy(model_uri, project_id, model_id=None, version_id=None, If not set, the default version is '2.7'. Python '3.5' is available when runtimeVersion is set to '1.4' and above. Python '2.7' works with all supported runtime versions. - version (str): optional, the payload of the new version. + model (dict): Optional, the JSON payload of the new model. The schema follows + [REST Model resource](https://cloud.google.com/ml-engine/reference/rest/v1/projects.models). + version (dict): Optional, the JSON payload of the new version. The schema follows + the [REST Version resource](https://cloud.google.com/ml-engine/reference/rest/v1/projects.models.versions) replace_existing_version (boolean): boolean flag indicates whether to replace existing version in case of conflict. set_default (boolean): boolean flag indicates whether to set the new @@ -57,7 +61,7 @@ def deploy(model_uri, project_id, model_id=None, version_id=None, model_uri = _search_dir_with_model(storage_client, model_uri) gcp_common.dump_file('/tmp/kfp/output/ml_engine/model_uri.txt', model_uri) - model = create_model(project_id, model_id) + model = create_model(project_id, model_id, model) model_name = model.get('name') version = create_version(model_name, model_uri, version_id, runtime_version, python_version, version, replace_existing_version, @@ -78,41 +82,27 @@ def _search_dir_with_model(storage_client, model_root_uri): if basename in KNOWN_MODEL_NAMES: logging.info('Found model file under {}.'.format(model_root_uri)) return model_root_uri - model_dir = _search_tf_export_root_dir(storage_client, bucket, blob_name) + model_dir = _search_tf_export_dir_base(storage_client, bucket, blob_name) if not model_dir: model_dir = model_root_uri return model_dir -def _search_tf_export_root_dir(storage_client, bucket, blob_name): - export_root_path = os.path.join(blob_name, 'export/') - logging.info('Searching model under {}.'.format(export_root_path)) - it = bucket.list_blobs(prefix=export_root_path, delimiter='/') +def _search_tf_export_dir_base(storage_client, bucket, export_dir_base): + logging.info('Searching model under export base dir: {}.'.format(export_dir_base)) + it = bucket.list_blobs(prefix=export_dir_base, delimiter='/') for _ in it.pages: # Iterate to the last page to get the full prefixes. pass - prefixes = it.prefixes - if not prefixes: - logging.info('No model was found under {}. Stop searching.'.format( - export_root_path)) - return None - if len(prefixes) > 1: - logging.info('Found multiple dirs under {}. Stop searching.'.format( - export_root_path)) + timestamped_dirs = [] + for sub_dir in it.prefixes: + dir_name = os.path.basename(os.path.normpath(sub_dir)) + if dir_name.isdigit(): + timestamped_dirs.append(sub_dir) + + if not timestamped_dirs: + logging.info('No timestamped sub-directory is found under {}'.format(export_dir_base)) return None - export_path = list(prefixes)[0] - return _search_tf_export_dir(storage_client, bucket, export_path) -def _search_tf_export_dir(storage_client, bucket, export_path): - it = bucket.list_blobs(prefix=export_path, delimiter='/') - for _ in it.pages: - # Iterate to the last page to get the full prefixes. - pass - prefixes = it.prefixes - if prefixes: - prefixes_list = list(prefixes) - prefixes_list.sort(reverse=True) - logging.info('Found TF model path {}.'.format(prefixes_list[0])) - return 'gs://{}/{}'.format(bucket.name, prefixes_list[0]) - logging.info('No model was found under {}. Stop searching.'.format( - export_path)) - return None + last_timestamped_dir = max(timestamped_dirs) + logging.info('Found timestamped sub-directory: {}.'.format(last_timestamped_dir)) + return 'gs://{}/{}'.format(bucket.name, last_timestamped_dir) diff --git a/component_sdk/python/tests/google/ml_engine/test__deploy.py b/component_sdk/python/tests/google/ml_engine/test__deploy.py index f8327afd42e..7c524625847 100644 --- a/component_sdk/python/tests/google/ml_engine/test__deploy.py +++ b/component_sdk/python/tests/google/ml_engine/test__deploy.py @@ -53,10 +53,7 @@ def test_deploy_tf_exporter_path(self, mock_set_default_version, mock_create_ver mock_create_model, mock_storage_client): prefixes_mock = mock.PropertyMock() - prefixes_mock.side_effect = [ - set(['uri/export/exporter/']), - set(['uri/export/exporter/123']), - ] + prefixes_mock.return_value = set(['uri/012/', 'uri/123/']) type(mock_storage_client().bucket().list_blobs()).prefixes = prefixes_mock mock_storage_client().bucket().list_blobs().__iter__.return_value = [] mock_storage_client().bucket().name = 'model' @@ -73,7 +70,7 @@ def test_deploy_tf_exporter_path(self, mock_set_default_version, mock_create_ver self.assertEqual(expected_version, result) mock_create_version.assert_called_with( 'projects/mock-project/models/mock-model', - 'gs://model/uri/export/exporter/123', + 'gs://model/uri/123/', None, # version_name None, # runtime_version None, # python_version