Skip to content

Commit

Permalink
Fix hard-coded model export dir name (#1015)
Browse files Browse the repository at this point in the history
* FIx hard-coded model export dir name

* Prefix is not the full gcs path.
  • Loading branch information
hongye-sun authored and k8s-ci-robot committed Mar 21, 2019
1 parent a60355a commit 5ccbbb9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 40 deletions.
60 changes: 25 additions & 35 deletions component_sdk/python/kfp_component/google/ml_engine/_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@

@decorators.SetParseFns(python_version=str, runtime_version=str)
def deploy(model_uri, project_id, model_id=None, version_id=None,
runtime_version=None, python_version=None, version=None,
runtime_version=None, python_version=None, model=None, version=None,
replace_existing_version=False, set_default=False, wait_interval=30):
"""Deploy a model to MLEngine from GCS URI
Args:
model_uri (str): required, the GCS URI which contains a model file.
Common used TF model search path (export/exporter) will be used
if exist.
model_uri (str): Required, the GCS URI which contains a model file.
If no model file is found, the same path will be treated as an export
base directory of a TF Estimator. The last time-stamped sub-directory
will be chosen as model URI.
project_id (str): required, the ID of the parent project.
model_id (str): optional, the user provided name of the model.
version_id (str): optional, the user provided name of the version.
Expand All @@ -46,7 +47,10 @@ def deploy(model_uri, project_id, model_id=None, version_id=None,
If not set, the default version is '2.7'. Python '3.5' is available
when runtimeVersion is set to '1.4' and above. Python '2.7' works
with all supported runtime versions.
version (str): optional, the payload of the new version.
model (dict): Optional, the JSON payload of the new model. The schema follows
[REST Model resource](https://cloud.google.com/ml-engine/reference/rest/v1/projects.models).
version (dict): Optional, the JSON payload of the new version. The schema follows
the [REST Version resource](https://cloud.google.com/ml-engine/reference/rest/v1/projects.models.versions)
replace_existing_version (boolean): boolean flag indicates whether to replace
existing version in case of conflict.
set_default (boolean): boolean flag indicates whether to set the new
Expand All @@ -57,7 +61,7 @@ def deploy(model_uri, project_id, model_id=None, version_id=None,
model_uri = _search_dir_with_model(storage_client, model_uri)
gcp_common.dump_file('/tmp/kfp/output/ml_engine/model_uri.txt',
model_uri)
model = create_model(project_id, model_id)
model = create_model(project_id, model_id, model)
model_name = model.get('name')
version = create_version(model_name, model_uri, version_id,
runtime_version, python_version, version, replace_existing_version,
Expand All @@ -78,41 +82,27 @@ def _search_dir_with_model(storage_client, model_root_uri):
if basename in KNOWN_MODEL_NAMES:
logging.info('Found model file under {}.'.format(model_root_uri))
return model_root_uri
model_dir = _search_tf_export_root_dir(storage_client, bucket, blob_name)
model_dir = _search_tf_export_dir_base(storage_client, bucket, blob_name)
if not model_dir:
model_dir = model_root_uri
return model_dir

def _search_tf_export_root_dir(storage_client, bucket, blob_name):
export_root_path = os.path.join(blob_name, 'export/')
logging.info('Searching model under {}.'.format(export_root_path))
it = bucket.list_blobs(prefix=export_root_path, delimiter='/')
def _search_tf_export_dir_base(storage_client, bucket, export_dir_base):
logging.info('Searching model under export base dir: {}.'.format(export_dir_base))
it = bucket.list_blobs(prefix=export_dir_base, delimiter='/')
for _ in it.pages:
# Iterate to the last page to get the full prefixes.
pass
prefixes = it.prefixes
if not prefixes:
logging.info('No model was found under {}. Stop searching.'.format(
export_root_path))
return None
if len(prefixes) > 1:
logging.info('Found multiple dirs under {}. Stop searching.'.format(
export_root_path))
timestamped_dirs = []
for sub_dir in it.prefixes:
dir_name = os.path.basename(os.path.normpath(sub_dir))
if dir_name.isdigit():
timestamped_dirs.append(sub_dir)

if not timestamped_dirs:
logging.info('No timestamped sub-directory is found under {}'.format(export_dir_base))
return None
export_path = list(prefixes)[0]
return _search_tf_export_dir(storage_client, bucket, export_path)

def _search_tf_export_dir(storage_client, bucket, export_path):
it = bucket.list_blobs(prefix=export_path, delimiter='/')
for _ in it.pages:
# Iterate to the last page to get the full prefixes.
pass
prefixes = it.prefixes
if prefixes:
prefixes_list = list(prefixes)
prefixes_list.sort(reverse=True)
logging.info('Found TF model path {}.'.format(prefixes_list[0]))
return 'gs://{}/{}'.format(bucket.name, prefixes_list[0])
logging.info('No model was found under {}. Stop searching.'.format(
export_path))
return None
last_timestamped_dir = max(timestamped_dirs)
logging.info('Found timestamped sub-directory: {}.'.format(last_timestamped_dir))
return 'gs://{}/{}'.format(bucket.name, last_timestamped_dir)
7 changes: 2 additions & 5 deletions component_sdk/python/tests/google/ml_engine/test__deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ def test_deploy_tf_exporter_path(self, mock_set_default_version, mock_create_ver
mock_create_model, mock_storage_client):

prefixes_mock = mock.PropertyMock()
prefixes_mock.side_effect = [
set(['uri/export/exporter/']),
set(['uri/export/exporter/123']),
]
prefixes_mock.return_value = set(['uri/012/', 'uri/123/'])
type(mock_storage_client().bucket().list_blobs()).prefixes = prefixes_mock
mock_storage_client().bucket().list_blobs().__iter__.return_value = []
mock_storage_client().bucket().name = 'model'
Expand All @@ -73,7 +70,7 @@ def test_deploy_tf_exporter_path(self, mock_set_default_version, mock_create_ver
self.assertEqual(expected_version, result)
mock_create_version.assert_called_with(
'projects/mock-project/models/mock-model',
'gs://model/uri/export/exporter/123',
'gs://model/uri/123/',
None, # version_name
None, # runtime_version
None, # python_version
Expand Down

0 comments on commit 5ccbbb9

Please sign in to comment.