diff --git a/component_sdk/python/kfp_component/core/__init__.py b/component_sdk/python/kfp_component/core/__init__.py index 7b1bc3a28bd..804e4cea535 100644 --- a/component_sdk/python/kfp_component/core/__init__.py +++ b/component_sdk/python/kfp_component/core/__init__.py @@ -12,4 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._kfp_execution_context import KfpExecutionContext \ No newline at end of file +from ._kfp_execution_context import KfpExecutionContext +from . import _display as display \ No newline at end of file diff --git a/component_sdk/python/kfp_component/core/_display.py b/component_sdk/python/kfp_component/core/_display.py new file mode 100644 index 00000000000..6f4c295a694 --- /dev/null +++ b/component_sdk/python/kfp_component/core/_display.py @@ -0,0 +1,104 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import threading +import logging + +_OUTPUT_PATH = os.environ.get('KFP_UI_METADATA_PATH', '/mlpipeline-ui-metadata.json') +_OUTPUT_FILE_LOCK = threading.Lock() + +def display(obj): + """Display an object to KFP UI. + + Args: + obj (object): the object to output the display metadata. It follows same + convention defined by IPython display API. The currently supported representation + functions: + + * `_repr_html_`: it returns a html content which will be converted into a + web-app metadata to KFP UI. + * `_repr_kfpmetadata_`: it returns a KFP metadata json object, which follows + the convention from https://www.kubeflow.org/docs/pipelines/output-viewer/. + + The supported builtin objects are HTML, Tensorboard, Link. + """ + obj_dir = dir(obj) + if '_repr_html_' in obj_dir: + display_html(obj) + + if '_repr_kfpmetadata_' in obj_dir: + display_kfpmetadata(obj) + +def display_html(obj): + """Display html representation to KFP UI. + """ + if '_repr_html_' not in dir(obj): + raise ValueError('_repr_html_ function is not present.') + html = obj._repr_html_() + _output_ui_metadata({ + 'type': 'web-app', + 'html': html + }) + +def display_kfpmetadata(obj): + """Display from KFP UI metadata + """ + if '_repr_kfpmetadata_' not in dir(obj): + raise ValueError('_repr_kfpmetadata_ function is not present.') + kfp_metadata = obj._repr_kfpmetadata_() + _output_ui_metadata(kfp_metadata) + +def _output_ui_metadata(output): + logging.info('Dumping metadata: {}'.format(output)) + with _OUTPUT_FILE_LOCK: + metadata = {} + if os.path.isfile(_OUTPUT_PATH): + with open(_OUTPUT_PATH, 'r') as f: + metadata = json.load(f) + + with open(_OUTPUT_PATH, 'w') as f: + if 'outputs' not in metadata: + metadata['outputs'] = [] + metadata['outputs'].append(output) + json.dump(metadata, f) + +class HTML(object): + """Class to hold html raw data. + """ + def __init__(self, data): + self._html = data + + def _repr_html_(self): + return self._html + +class Tensorboard(object): + """Class to hold tensorboard metadata. + """ + def __init__(self, job_dir): + self._job_dir = job_dir + + def _repr_kfpmetadata_(self): + return { + 'type': 'tensorboard', + 'source': self._job_dir + } + +class Link(HTML): + """Class to hold an HTML hyperlink data. + """ + def __init__(self, href, text): + super(Link, self).__init__( + '{}'.format(href, text)) \ No newline at end of file diff --git a/component_sdk/python/kfp_component/google/__init__.py b/component_sdk/python/kfp_component/google/__init__.py new file mode 100644 index 00000000000..c2fc82ab83f --- /dev/null +++ b/component_sdk/python/kfp_component/google/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/component_sdk/python/kfp_component/google/common/__init__.py b/component_sdk/python/kfp_component/google/common/__init__.py new file mode 100644 index 00000000000..df0f4289e20 --- /dev/null +++ b/component_sdk/python/kfp_component/google/common/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._utils import normalize_name, dump_file, check_resource_changed diff --git a/component_sdk/python/kfp_component/google/common/_utils.py b/component_sdk/python/kfp_component/google/common/_utils.py new file mode 100644 index 00000000000..4f7286ac8c7 --- /dev/null +++ b/component_sdk/python/kfp_component/google/common/_utils.py @@ -0,0 +1,92 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +import os + +def normalize_name(name, + valid_first_char_pattern='a-zA-Z', + valid_char_pattern='0-9a-zA-Z_', + invalid_char_placeholder='_', + prefix_placeholder='x_'): + """Normalize a name to a valid resource name. + + Uses ``valid_first_char_pattern`` and ``valid_char_pattern`` regex pattern + to find invalid characters from ``name`` and replaces them with + ``invalid_char_placeholder`` or prefix the name with ``prefix_placeholder``. + + Args: + name: The name to be normalized. + valid_first_char_pattern: The regex pattern for the first character. + valid_char_pattern: The regex pattern for all the characters in the name. + invalid_char_placeholder: The placeholder to replace invalid characters. + prefix_placeholder: The placeholder to prefix the name if the first char + is invalid. + + Returns: + The normalized name. Unchanged if all characters are valid. + """ + if not name: + return name + normalized_name = re.sub('[^{}]+'.format(valid_char_pattern), + invalid_char_placeholder, name) + if not re.match('[{}]'.format(valid_first_char_pattern), + normalized_name[0]): + normalized_name = prefix_placeholder + normalized_name + if name != normalized_name: + logging.info('Normalize name from "{}" to "{}".'.format( + name, normalized_name)) + return normalized_name + +def dump_file(path, content): + """Dumps string into local file. + + Args: + path: the local path to the file. + content: the string content to dump. + """ + directory = os.path.dirname(path) + if not os.path.exists(directory): + os.makedirs(directory) + elif os.path.exists(path): + logging.warning('The file {} will be overwritten.'.format(path)) + with open(path, 'w') as f: + f.write(content) + +def check_resource_changed(requested_resource, + existing_resource, property_names): + """Check if a resource has been changed. + + The function checks requested resource with existing resource + by comparing specified property names. Check fails if any property + name in the list is in ``requested_resource`` but its value is + different with the value in ``existing_resource``. + + Args: + requested_resource: the user requested resource paylod. + existing_resource: the existing resource payload from data storage. + property_names: a list of property names. + + Return: + True if ``requested_resource`` has been changed. + """ + for property_name in property_names: + if not property_name in requested_resource: + continue + existing_value = existing_resource.get(property_name, None) + if requested_resource[property_name] != existing_value: + return True + return False + diff --git a/component_sdk/python/kfp_component/google/ml_engine/__init__.py b/component_sdk/python/kfp_component/google/ml_engine/__init__.py new file mode 100644 index 00000000000..7075ef99d18 --- /dev/null +++ b/component_sdk/python/kfp_component/google/ml_engine/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module that contains a set of commands to call ML Engine APIs + +The commands are aware of KFP execution context and can work under +retry and cancellation context. The currently supported commands +are: train, batch_prediction, create_model, create_version and +delete_version. + +TODO(hongyes): Provides full ML Engine API support. +""" + +from ._create_job import create_job +from ._create_model import create_model +from ._create_version import create_version +from ._delete_version import delete_version +from ._train import train +from ._batch_predict import batch_predict \ No newline at end of file diff --git a/component_sdk/python/kfp_component/google/ml_engine/_batch_predict.py b/component_sdk/python/kfp_component/google/ml_engine/_batch_predict.py new file mode 100644 index 00000000000..ef086bcc011 --- /dev/null +++ b/component_sdk/python/kfp_component/google/ml_engine/_batch_predict.py @@ -0,0 +1,79 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from ._create_job import create_job + +def batch_predict(project_id, model_path, input_paths, input_data_format, + output_path, region, output_data_format=None, prediction_input=None, job_id_prefix=None, + wait_interval=30): + """Creates a MLEngine batch prediction job. + + Args: + project_id (str): Required. The ID of the parent project of the job. + model_path (str): Required. The path to the model. It can be either: + `projects/[PROJECT_ID]/models/[MODEL_ID]` or + `projects/[PROJECT_ID]/models/[MODEL_ID]/versions/[VERSION_ID]` + or a GCS path of a model file. + input_paths (list): Required. The Google Cloud Storage location of + the input data files. May contain wildcards. + input_data_format (str): Required. The format of the input data files. + See https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#DataFormat. + output_path (str): Required. The output Google Cloud Storage location. + region (str): Required. The Google Compute Engine region to run the + prediction job in. + output_data_format (str): Optional. Format of the output data files, + defaults to JSON. + prediction_input (dict): Input parameters to create a prediction job. + job_id_prefix (str): the prefix of the generated job id. + wait_interval (int): optional wait interval between calls + to get job status. Defaults to 30. + """ + if not prediction_input: + prediction_input = {} + if not model_path: + raise ValueError('model_path must be provided.') + if _is_model_name(model_path): + prediction_input['modelName'] = model_path + elif _is_model_version_name(model_path): + prediction_input['versionName'] = model_path + elif _is_gcs_path(model_path): + prediction_input['uri'] = model_path + else: + raise ValueError('model_path value is invalid.') + + if input_paths: + prediction_input['inputPaths'] = input_paths + if input_data_format: + prediction_input['dataFormat'] = input_data_format + if output_path: + prediction_input['outputPath'] = output_path + if output_data_format: + prediction_input['outputDataFormat'] = output_data_format + if region: + prediction_input['region'] = region + job = { + 'predictionInput': prediction_input + } + create_job(project_id, job, job_id_prefix, wait_interval) + +def _is_model_name(name): + return re.match(r'/projects/[^/]+/models/[^/]+$', name) + +def _is_model_version_name(name): + return re.match(r'/projects/[^/]+/models/[^/]+/versions/[^/]+$', name) + +def _is_gcs_path(name): + return name.startswith('gs://') \ No newline at end of file diff --git a/component_sdk/python/kfp_component/google/ml_engine/_client.py b/component_sdk/python/kfp_component/google/ml_engine/_client.py new file mode 100644 index 00000000000..f40f1e2aa94 --- /dev/null +++ b/component_sdk/python/kfp_component/google/ml_engine/_client.py @@ -0,0 +1,208 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +import googleapiclient.discovery as discovery +from googleapiclient import errors + +class MLEngineClient: + """ Client for calling MLEngine APIs. + """ + def __init__(self): + self._ml_client = discovery.build('ml', 'v1') + + def create_job(self, project_id, job): + """Create a new job. + + Args: + project_id: the ID of the parent project. + job: the payload of the job. + + Returns: + The created job. + """ + return self._ml_client.projects().jobs().create( + parent = 'projects/{}'.format(project_id), + body = job + ).execute() + + def cancel_job(self, project_id, job_id): + """Cancel the specified job. + + Args: + project_id: the parent project ID of the job. + job_id: the ID of the job. + """ + job_name = 'projects/{}/jobs/{}'.format(project_id, job_id) + self._ml_client.projects().jobs().cancel( + name = job_name, + body = { + 'name': job_name + }, + ).execute() + + def get_job(self, project_id, job_id): + """Gets the job by ID. + + Args: + project_id: the ID of the parent project. + job_id: the ID of the job to retrieve. + Returns: + The retrieved job payload. + """ + job_name = 'projects/{}/jobs/{}'.format(project_id, job_id) + return self._ml_client.projects().jobs().get( + name=job_name).execute() + + def create_model(self, project_id, model): + """Creates a new model. + + Args: + project_id: the ID of the parent project. + model: the payload of the model. + Returns: + The created model. + """ + return self._ml_client.projects().models().create( + parent = 'projects/{}'.format(project_id), + body = model + ).execute() + + def get_model(self, project_id, model_name): + """Gets a model. + + Args: + project_id: the ID of the parent project. + model_name: the name of the model. + Returns: + The retrieved model. + """ + return self._ml_client.projects().models().get( + name = 'projects/{}/models/{}'.format( + project_id, model_name) + ).execute() + + def create_version(self, project_id, model_name, version): + """Creates a new version. + + Args: + project_id: the ID of the parent project. + model_name: the name of the parent model. + version: the payload of the version. + + Returns: + The created version. + """ + return self._ml_client.projects().models().versions().create( + parent = 'projects/{}/models/{}'.format(project_id, model_name), + body = version + ).execute() + + def get_version(self, project_id, model_name, version_name): + """Gets a version. + + Args: + project_id: the ID of the parent project. + model_name: the name of the parent model. + version_name: the name of the version. + + Returns: + The retrieved version. None if the version is not found. + """ + try: + return self._ml_client.projects().models().versions().get( + name = 'projects/{}/models/{}/versions/{}'.format( + project_id, model_name, version_name) + ).execute() + except errors.HttpError as e: + if e.resp.status == 404: + return None + raise + + def delete_version(self, project_id, model_name, version_name): + """Deletes a version. + + Args: + project_id: the ID of the parent project. + model_name: the name of the parent model. + version_name: the name of the version. + + Returns: + The delete operation. None if the version is not found. + """ + try: + return self._ml_client.projects().models().versions().delete( + name = 'projects/{}/models/{}/versions/{}'.format( + project_id, model_name, version_name) + ).execute() + except errors.HttpError as e: + if e.resp.status == 404: + logging.info('The version has already been deleted.') + return None + raise + + def get_operation(self, operation_name): + """Gets an operation. + + Args: + operation_name: the name of the operation. + + Returns: + The retrieved operation. + """ + return self._ml_client.projects().operations().get( + name = operation_name + ).execute() + + def wait_for_operation_done(self, operation_name, wait_interval): + """Waits for an operation to be done. + + Args: + operation_name: the name of the operation. + wait_interval: the wait interview between pulling job + status. + + Returns: + The completed operation. + """ + operation = None + while True: + operation = self._ml_client.projects().operations().get( + name = operation_name + ).execute() + done = operation.get('done', False) + if done: + break + logging.info('Operation {} is not done. Wait for {}s.'.format(operation_name, wait_interval)) + time.sleep(wait_interval) + error = operation.get('error', None) + if error: + raise RuntimeError('Failed to complete operation {}: {} {}'.format( + operation_name, + error.get('code', 'Unknown code'), + error.get('message', 'Unknown message'), + )) + return operation + + def cancel_operation(self, operation_name): + """Cancels an operation. + + Args: + operation_name: the name of the operation. + """ + self._ml_client.projects().operations().cancel( + name = operation_name + ).execute() \ No newline at end of file diff --git a/component_sdk/python/kfp_component/google/ml_engine/_common_ops.py b/component_sdk/python/kfp_component/google/ml_engine/_common_ops.py new file mode 100644 index 00000000000..23b6008e999 --- /dev/null +++ b/component_sdk/python/kfp_component/google/ml_engine/_common_ops.py @@ -0,0 +1,66 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +from googleapiclient import errors + +def wait_existing_version(ml_client, project_id, model_name, + version_name, wait_interval): + while True: + existing_version = ml_client.get_version( + project_id, model_name, version_name) + if not existing_version: + return None + state = existing_version.get('state', None) + if not state in ['CREATING', 'DELETING', 'UPDATING']: + return existing_version + logging.info('Version is in {} state. Wait for {}s'.format( + state, wait_interval + )) + time.sleep(wait_interval) + +def wait_for_operation_done(ml_client, operation_name, action, wait_interval): + """Waits for an operation to be done. + + Args: + operation_name: the name of the operation. + action: the action name of the operation. + wait_interval: the wait interview between pulling job + status. + + Returns: + The completed operation. + + Raises: + RuntimeError if the operation has error. + """ + operation = None + while True: + operation = ml_client.get_operation(operation_name) + done = operation.get('done', False) + if done: + break + logging.info('Operation {} is not done. Wait for {}s.'.format(operation_name, wait_interval)) + time.sleep(wait_interval) + error = operation.get('error', None) + if error: + raise RuntimeError('Failed to complete {} operation {}: {} {}'.format( + action, + operation_name, + error.get('code', 'Unknown code'), + error.get('message', 'Unknown message'), + )) + return operation \ No newline at end of file diff --git a/component_sdk/python/kfp_component/google/ml_engine/_create_job.py b/component_sdk/python/kfp_component/google/ml_engine/_create_job.py new file mode 100644 index 00000000000..c682e108967 --- /dev/null +++ b/component_sdk/python/kfp_component/google/ml_engine/_create_job.py @@ -0,0 +1,131 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import logging +import re +import time + +from googleapiclient import errors + +from kfp_component.core import KfpExecutionContext, display +from ._client import MLEngineClient +from .. import common as gcp_common + +def create_job(project_id, job, job_id_prefix=None, wait_interval=30): + """Creates a MLEngine job. + + Args: + project_id: the ID of the parent project of the job. + job: the payload of the job. Must have ``jobId`` + and ``trainingInput`` or ``predictionInput`. + job_id_prefix: the prefix of the generated job id. + wait_interval: optional wait interval between calls + to get job status. Defaults to 30. + + """ + return CreateJobOp(project_id, job, job_id_prefix, + wait_interval).execute_and_wait() + +class CreateJobOp: + def __init__(self, project_id, job, job_id_prefix=None, wait_interval=30): + self._ml = MLEngineClient() + self._project_id = project_id + self._job_id_prefix = job_id_prefix + self._job_id = None + self._job = job + self._wait_interval = wait_interval + + def execute_and_wait(self): + with KfpExecutionContext(on_cancel=self._cancel) as ctx: + self._set_job_id(ctx.context_id()) + self._dump_metadata() + self._create_job() + finished_job = self._wait_for_done() + self._dump_job(finished_job) + if finished_job['state'] != 'SUCCEEDED': + raise RuntimeError('Job failed with state {}. Error: {}'.format( + finished_job['state'], finished_job.get('errorMessage', ''))) + return finished_job + + def _set_job_id(self, context_id): + if self._job_id_prefix: + job_id = self._job_id_prefix + context_id[:16] + else: + job_id = 'job_' + context_id + job_id = gcp_common.normalize_name(job_id) + self._job_id = job_id + self._job['jobId'] = job_id + + def _cancel(self): + try: + logging.info('Cancelling job {}.'.format(self._job_id)) + self._ml.cancel_job(self._project_id, self._job_id) + logging.info('Cancelled job {}.'.format(self._job_id)) + except errors.HttpError as e: + # Best effort to cancel the job + logging.error('Failed to cancel the job: {}'.format(e)) + pass + + def _create_job(self): + try: + self._ml.create_job( + project_id = self._project_id, + job = self._job + ) + except errors.HttpError as e: + if e.resp.status == 409: + if not self._is_dup_job(): + logging.error('Another job has been created with same name before: {}'.format(self._job_id)) + raise + logging.info('The job {} has been submitted before. Continue waiting.'.format(self._job_id)) + else: + logging.error('Failed to create job.\nPayload: {}\nError: {}'.format(self._job, e)) + raise + + def _is_dup_job(self): + existing_job = self._ml.get_job(self._project_id, self._job_id) + return existing_job.get('trainingInput', None) == self._job.get('trainingInput', None) \ + and existing_job.get('predictionInput', None) == self._job.get('predictionInput', None) + + def _wait_for_done(self): + while True: + job = self._ml.get_job(self._project_id, self._job_id) + if job.get('state', None) in ['SUCCEEDED', 'FAILED', 'CANCELLED']: + return job + # Move to config from flag + logging.info('job status is {}, wait for {}s'.format( + job.get('state', None), self._wait_interval)) + time.sleep(self._wait_interval) + + def _dump_metadata(self): + display.display(display.Link( + 'https://console.cloud.google.com/mlengine/jobs/{}?project={}'.format( + self._job_id, self._project_id), + 'Job Details' + )) + display.display(display.Link( + 'https://console.cloud.google.com/logs/viewer?project={}&resource=ml_job/job_id/{}&interval=NO_LIMIT'.format( + self._project_id, self._job_id), + 'Logs' + )) + if 'trainingInput' in self._job and 'jobDir' in self._job['trainingInput']: + display.display(display.Tensorboard( + self._job['trainingInput']['jobDir'])) + + def _dump_job(self, job): + logging.info('Dumping job: {}'.format(job)) + gcp_common.dump_file('/tmp/outputs/output.txt', json.dumps(job)) + gcp_common.dump_file('/tmp/outputs/job_id.txt', job['jobId']) diff --git a/component_sdk/python/kfp_component/google/ml_engine/_create_model.py b/component_sdk/python/kfp_component/google/ml_engine/_create_model.py new file mode 100644 index 00000000000..11504069372 --- /dev/null +++ b/component_sdk/python/kfp_component/google/ml_engine/_create_model.py @@ -0,0 +1,92 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import logging + +from googleapiclient import errors + +from kfp_component.core import KfpExecutionContext, display +from ._client import MLEngineClient +from .. import common as gcp_common + +def create_model(project_id, name=None, model=None): + """Creates a MLEngine model. + + Args: + project_id (str): the ID of the parent project of the model. + name (str): optional, the name of the model. If absent, a new name will + be generated. + model (dict): the payload of the model. + """ + return CreateModelOp(project_id, name, model).execute() + +class CreateModelOp: + def __init__(self, project_id, name, model): + self._ml = MLEngineClient() + self._project_id = project_id + self._model_name = name + if model: + self._model = model + else: + self._model = {} + + def execute(self): + with KfpExecutionContext() as ctx: + self._set_model_name(ctx.context_id()) + self._dump_metadata() + try: + created_model = self._ml.create_model( + project_id = self._project_id, + model = self._model) + except errors.HttpError as e: + if e.resp.status == 409: + existing_model = self._ml.get_model( + self._project_id, self._model_name) + if not self._is_dup_model(existing_model): + raise + logging.info('The same model {} has been submitted' + ' before. Continue the operation.'.format( + self._model_name)) + created_model = existing_model + else: + raise + self._dump_model(created_model) + return created_model + + def _set_model_name(self, context_id): + if not self._model_name: + self._model_name = 'model_' + context_id + self._model['name'] = gcp_common.normalize_name(self._model_name) + + + def _is_dup_model(self, existing_model): + return not gcp_common.check_resource_changed( + self._model, + existing_model, + ['description', 'regions', + 'onlinePredictionLogging', 'labels']) + + def _dump_metadata(self): + display.display(display.Link( + 'https://console.cloud.google.com/mlengine/models/{}?project={}'.format( + self._model_name, self._project_id), + 'Model Details' + )) + + def _dump_model(self, model): + logging.info('Dumping model: {}'.format(model)) + gcp_common.dump_file('/tmp/outputs/output.txt', json.dumps(model)) + gcp_common.dump_file('/tmp/outputs/model_name.txt', self._model_name) \ No newline at end of file diff --git a/component_sdk/python/kfp_component/google/ml_engine/_create_version.py b/component_sdk/python/kfp_component/google/ml_engine/_create_version.py new file mode 100644 index 00000000000..fdea66c56ec --- /dev/null +++ b/component_sdk/python/kfp_component/google/ml_engine/_create_version.py @@ -0,0 +1,172 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import time + +from googleapiclient import errors +from fire import decorators + +from kfp_component.core import KfpExecutionContext, display +from ._client import MLEngineClient +from .. import common as gcp_common +from ._common_ops import wait_existing_version, wait_for_operation_done + +@decorators.SetParseFns(python_version=str, runtime_version=str) +def create_version(project_id, model_name, deployemnt_uri=None, version_name=None, + runtime_version=None, python_version=None, version=None, + replace_existing=False, wait_interval=30): + """Creates a MLEngine version and wait for the operation to be done. + + Args: + project_id (str): required, the ID of the parent project. + model_name (str): required, the name of the parent model. + deployment_uri (str): optional, the Google Cloud Storage location of + the trained model used to create the version. + version_name (str): optional, the name of the version. If it is not + provided, the operation uses a random name. + runtime_version (str): optinal, the Cloud ML Engine runtime version + to use for this deployment. If not set, Cloud ML Engine uses + the default stable version, 1.0. + python_version (str): optinal, the version of Python used in prediction. + If not set, the default version is '2.7'. Python '3.5' is available + when runtimeVersion is set to '1.4' and above. Python '2.7' works + with all supported runtime versions. + version (str): optional, the payload of the new version. + replace_existing (boolean): boolean flag indicates whether to replace + existing version in case of conflict. + wait_interval (int): the interval to wait for a long running operation. + """ + if not version: + version = {} + if deployemnt_uri: + version['deploymentUri'] = deployemnt_uri + if version_name: + version['name'] = version_name + if runtime_version: + version['runtimeVersion'] = runtime_version + if python_version: + version['pythonVersion'] = python_version + + return CreateVersionOp(project_id, model_name, version, + replace_existing, wait_interval).execute_and_wait() + +class CreateVersionOp: + def __init__(self, project_id, model_name, version, + replace_existing, wait_interval): + self._ml = MLEngineClient() + self._project_id = project_id + self._model_name = gcp_common.normalize_name(model_name) + self._version_name = None + self._version = version + self._replace_existing = replace_existing + self._wait_interval = wait_interval + self._create_operation_name = None + self._delete_operation_name = None + + def execute_and_wait(self): + with KfpExecutionContext(on_cancel=self._cancel) as ctx: + self._set_version_name(ctx.context_id()) + self._dump_metadata() + existing_version = wait_existing_version(self._ml, + self._project_id, self._model_name, self._version_name, + self._wait_interval) + if existing_version and self._is_dup_version(existing_version): + return self._handle_completed_version(existing_version) + + if existing_version and self._replace_existing: + logging.info('Deleting existing version...') + self._delete_version_and_wait() + elif existing_version: + raise RuntimeError( + 'Existing version conflicts with the name of the new version.') + + created_version = self._create_version_and_wait() + return self._handle_completed_version(created_version) + + def _set_version_name(self, context_id): + version_name = self._version.get('name', None) + if not version_name: + version_name = 'ver_' + context_id + version_name = gcp_common.normalize_name(version_name) + self._version_name = version_name + self._version['name'] = version_name + + + def _cancel(self): + if self._delete_operation_name: + self._ml.cancel_operation(self._delete_operation_name) + + if self._create_operation_name: + self._ml.cancel_operation(self._create_operation_name) + + def _create_version_and_wait(self): + operation = self._ml.create_version(self._project_id, + self._model_name, self._version) + # Cache operation name for cancellation. + self._create_operation_name = operation.get('name') + try: + operation = wait_for_operation_done( + self._ml, + self._create_operation_name, + 'create version', + self._wait_interval) + finally: + self._create_operation_name = None + return operation.get('response', None) + + def _delete_version_and_wait(self): + operation = self._ml.delete_version( + self._project_id, self._model_name, self._version_name) + # Cache operation name for cancellation. + self._delete_operation_name = operation.get('name') + try: + wait_for_operation_done( + self._ml, + self._delete_operation_name, + 'delete version', + self._wait_interval) + finally: + self._delete_operation_name = None + + def _handle_completed_version(self, version): + state = version.get('state', None) + if state == 'FAILED': + error_message = version.get('errorMessage', 'Unknown failure') + raise RuntimeError('Version is in failed state: {}'.format( + error_message)) + self._dump_version(version) + return version + + def _dump_metadata(self): + display.display(display.Link( + 'https://console.cloud.google.com/mlengine/models/{}/versions/{}?project={}'.format( + self._model_name, self._version_name, self._project_id), + 'Version Details' + )) + + def _dump_version(self, version): + logging.info('Dumping version: {}'.format(version)) + gcp_common.dump_file('/tmp/outputs/output.txt', json.dumps(version)) + gcp_common.dump_file('/tmp/outputs/version_name.txt', version['name']) + + def _is_dup_version(self, existing_version): + return not gcp_common.check_resource_changed( + self._version, + existing_version, + ['description', 'deploymentUri', + 'runtimeVersion', 'machineType', 'labels', + 'framework', 'pythonVersion', 'autoScaling', + 'manualScaling']) \ No newline at end of file diff --git a/component_sdk/python/kfp_component/google/ml_engine/_delete_version.py b/component_sdk/python/kfp_component/google/ml_engine/_delete_version.py new file mode 100644 index 00000000000..4bc68e2205f --- /dev/null +++ b/component_sdk/python/kfp_component/google/ml_engine/_delete_version.py @@ -0,0 +1,72 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging + +from googleapiclient import errors + +from kfp_component.core import KfpExecutionContext +from ._client import MLEngineClient +from .. import common as gcp_common +from ._common_ops import wait_existing_version, wait_for_operation_done + +def delete_version(project_id, model_name, version_name, wait_interval=30): + """Deletes a MLEngine version and wait. + + Args: + project_id (str): required, the ID of the parent project. + model_name (str): required, the name of the parent model. + version_name (str): required, the name of the version. + wait_interval (int): the interval to wait for a long running operation. + """ + DeleteVersionOp(project_id, model_name, version_name, + wait_interval).execute_and_wait() + +class DeleteVersionOp: + def __init__(self, project_id, model_name, version_name, wait_interval): + self._ml = MLEngineClient() + self._project_id = project_id + self._model_name = gcp_common.normalize_name(model_name) + self._version_name = gcp_common.normalize_name(version_name) + self._wait_interval = wait_interval + self._delete_operation_name = None + + def execute_and_wait(self): + with KfpExecutionContext(on_cancel=self._cancel): + existing_version = wait_existing_version(self._ml, + self._project_id, self._model_name, self._version_name, + self._wait_interval) + if not existing_version: + logging.info('The version has already been deleted.') + return None + + logging.info('Deleting existing version...') + operation = self._ml.delete_version( + self._project_id, self._model_name, self._version_name) + # Cache operation name for cancellation. + self._delete_operation_name = operation.get('name') + try: + wait_for_operation_done( + self._ml, + self._delete_operation_name, + 'delete version', + self._wait_interval) + finally: + self._delete_operation_name = None + return None + + def _cancel(self): + if self._delete_operation_name: + self._ml.cancel_operation(self._delete_operation_name) \ No newline at end of file diff --git a/component_sdk/python/kfp_component/google/ml_engine/_train.py b/component_sdk/python/kfp_component/google/ml_engine/_train.py new file mode 100644 index 00000000000..b32be1fff71 --- /dev/null +++ b/component_sdk/python/kfp_component/google/ml_engine/_train.py @@ -0,0 +1,71 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fire import decorators +from ._create_job import create_job + +@decorators.SetParseFns(python_version=str, runtime_version=str) +def train(project_id, python_module, package_uris, + region, args=None, job_dir=None, python_version=None, + runtime_version=None, training_input=None, job_id_prefix=None, + wait_interval=30): + """Creates a MLEngine training job. + + Args: + project_id (str): Required. The ID of the parent project of the job. + python_module (str): Required. The Python module name to run after + installing the packages. + package_uris (list): Required. The Google Cloud Storage location of + the packages with the training program and any additional + dependencies. The maximum number of package URIs is 100. + region (str): Required. The Google Compute Engine region to run the + training job in + args (list): Command line arguments to pass to the program. + job_dir (str): A Google Cloud Storage path in which to store training + outputs and other data needed for training. This path is passed + to your TensorFlow program as the '--job-dir' command-line + argument. The benefit of specifying this field is that Cloud ML + validates the path for use in training. + python_version (str): Optional. The version of Python used in training. + If not set, the default version is '2.7'. Python '3.5' is + available when runtimeVersion is set to '1.4' and above. + Python '2.7' works with all supported runtime versions. + runtime_version (str): Optional. The Cloud ML Engine runtime version + to use for training. If not set, Cloud ML Engine uses the + default stable version, 1.0. + training_input (dict): Input parameters to create a training job. + job_id_prefix (str): the prefix of the generated job id. + wait_interval (int): optional wait interval between calls + to get job status. Defaults to 30. + """ + if not training_input: + training_input = {} + if python_module: + training_input['pythonModule'] = python_module + if package_uris: + training_input['packageUris'] = package_uris + if region: + training_input['region'] = region + if args: + training_input['args'] = args + if job_dir: + training_input['jobDir'] = job_dir + if python_version: + training_input['pythonVersion'] = python_version + if runtime_version: + training_input['runtimeVersion'] = runtime_version + job = { + 'trainingInput': training_input + } + return create_job(project_id, job, job_id_prefix, wait_interval) \ No newline at end of file diff --git a/component_sdk/python/requirements.txt b/component_sdk/python/requirements.txt index 187c82d92aa..2aafe32c46b 100644 --- a/component_sdk/python/requirements.txt +++ b/component_sdk/python/requirements.txt @@ -1,2 +1,3 @@ kubernetes == 8.0.1 fire == 0.1.3 +google-api-python-client == 1.7.8 diff --git a/component_sdk/python/tests/core/test__display.py b/component_sdk/python/tests/core/test__display.py new file mode 100644 index 00000000000..04aee8a53cf --- /dev/null +++ b/component_sdk/python/tests/core/test__display.py @@ -0,0 +1,80 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from kfp_component.core import display + +import mock +import unittest + +@mock.patch('kfp_component.core._display.json') +@mock.patch('kfp_component.core._display.os') +@mock.patch('kfp_component.core._display.open') +class DisplayTest(unittest.TestCase): + + def test_display_html(self, mock_open, mock_os, mock_json): + mock_os.path.isfile.return_value = False + + display.display(display.HTML('

test

')) + + mock_json.dump.assert_called_with({ + 'outputs': [{ + 'type': 'web-app', + 'html': '

test

' + }] + }, mock.ANY) + + def test_display_html_append(self, mock_open, mock_os, mock_json): + mock_os.path.isfile.return_value = True + mock_json.load.return_value = { + 'outputs': [{ + 'type': 'web-app', + 'html': '

test 1

' + }] + } + + display.display(display.HTML('

test 2

')) + + mock_json.dump.assert_called_with({ + 'outputs': [{ + 'type': 'web-app', + 'html': '

test 1

' + },{ + 'type': 'web-app', + 'html': '

test 2

' + }] + }, mock.ANY) + + def test_display_tensorboard(self, mock_open, mock_os, mock_json): + mock_os.path.isfile.return_value = False + + display.display(display.Tensorboard('gs://job/dir')) + + mock_json.dump.assert_called_with({ + 'outputs': [{ + 'type': 'tensorboard', + 'source': 'gs://job/dir' + }] + }, mock.ANY) + + def test_display_link(self, mock_open, mock_os, mock_json): + mock_os.path.isfile.return_value = False + + display.display(display.Link('https://test/link', 'Test Link')) + + mock_json.dump.assert_called_with({ + 'outputs': [{ + 'type': 'web-app', + 'html': 'Test Link' + }] + }, mock.ANY) diff --git a/component_sdk/python/tests/core/test__kfp_execution_context.py b/component_sdk/python/tests/core/test__kfp_execution_context.py index 2789809f964..cfc44f043db 100644 --- a/component_sdk/python/tests/core/test__kfp_execution_context.py +++ b/component_sdk/python/tests/core/test__kfp_execution_context.py @@ -24,7 +24,7 @@ @mock.patch('kubernetes.config.load_incluster_config') @mock.patch('kubernetes.client.CoreV1Api') -class BaseOpTest(unittest.TestCase): +class KfpExecutionContextTest(unittest.TestCase): def test_init_succeed_without_pod_name(self, mock_k8s_client, mock_load_config): diff --git a/component_sdk/python/tests/google/__init__.py b/component_sdk/python/tests/google/__init__.py new file mode 100644 index 00000000000..c2fc82ab83f --- /dev/null +++ b/component_sdk/python/tests/google/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/component_sdk/python/tests/google/ml_engine/__init__.py b/component_sdk/python/tests/google/ml_engine/__init__.py new file mode 100644 index 00000000000..c2fc82ab83f --- /dev/null +++ b/component_sdk/python/tests/google/ml_engine/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/component_sdk/python/tests/google/ml_engine/test__create_job.py b/component_sdk/python/tests/google/ml_engine/test__create_job.py new file mode 100644 index 00000000000..309e1a92715 --- /dev/null +++ b/component_sdk/python/tests/google/ml_engine/test__create_job.py @@ -0,0 +1,178 @@ + +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import mock +import unittest + +from googleapiclient import errors +from kfp_component.google.ml_engine import create_job + +CREATE_JOB_MODULE = 'kfp_component.google.ml_engine._create_job' + +@mock.patch(CREATE_JOB_MODULE + '.display.display') +@mock.patch(CREATE_JOB_MODULE + '.gcp_common.dump_file') +@mock.patch(CREATE_JOB_MODULE + '.KfpExecutionContext') +@mock.patch(CREATE_JOB_MODULE + '.MLEngineClient') +class TestCreateJob(unittest.TestCase): + + def test_create_job_succeed(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + mock_kfp_context().__enter__().context_id.return_value = 'ctx1' + job = {} + returned_job = { + 'jobId': 'job_ctx1', + 'state': 'SUCCEEDED' + } + mock_mlengine_client().get_job.return_value = ( + returned_job) + + result = create_job('mock_project', job) + + self.assertEqual(returned_job, result) + mock_mlengine_client().create_job.assert_called_with( + project_id = 'mock_project', + job = { + 'jobId': 'job_ctx1' + } + ) + + def test_create_job_with_job_id_prefix_succeed(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + mock_kfp_context().__enter__().context_id.return_value = 'ctx1' + job = {} + returned_job = { + 'jobId': 'mock_job_ctx1', + 'state': 'SUCCEEDED' + } + mock_mlengine_client().get_job.return_value = ( + returned_job) + + result = create_job('mock_project', job, job_id_prefix='mock_job_') + + self.assertEqual(returned_job, result) + mock_mlengine_client().create_job.assert_called_with( + project_id = 'mock_project', + job = { + 'jobId': 'mock_job_ctx1' + } + ) + + def test_execute_retry_job_success(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + mock_kfp_context().__enter__().context_id.return_value = 'ctx1' + job = {} + returned_job = { + 'jobId': 'job_ctx1', + 'state': 'SUCCEEDED' + } + mock_mlengine_client().create_job.side_effect = errors.HttpError( + resp = mock.Mock(status=409), + content = b'conflict' + ) + mock_mlengine_client().get_job.return_value = returned_job + + result = create_job('mock_project', job) + + self.assertEqual(returned_job, result) + + def test_create_job_use_context_id_as_name(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + context_id = 'ctx1' + job = {} + returned_job = { + 'jobId': 'job_ctx1', + 'state': 'SUCCEEDED' + } + mock_mlengine_client().get_job.return_value = ( + returned_job) + mock_kfp_context().__enter__().context_id.return_value = context_id + + create_job('mock_project', job) + + mock_mlengine_client().create_job.assert_called_with( + project_id = 'mock_project', + job = { + 'jobId': 'job_ctx1' + } + ) + + def test_execute_conflict_fail(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + mock_kfp_context().__enter__().context_id.return_value = 'ctx1' + job = {} + returned_job = { + 'jobId': 'job_ctx1', + 'trainingInput': { + 'modelDir': 'test' + }, + 'state': 'SUCCEEDED' + } + mock_mlengine_client().create_job.side_effect = errors.HttpError( + resp = mock.Mock(status=409), + content = b'conflict' + ) + mock_mlengine_client().get_job.return_value = returned_job + + with self.assertRaises(errors.HttpError) as context: + create_job('mock_project', job) + + self.assertEqual(409, context.exception.resp.status) + + def test_execute_create_job_fail(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + mock_kfp_context().__enter__().context_id.return_value = 'ctx1' + job = {} + mock_mlengine_client().create_job.side_effect = errors.HttpError( + resp = mock.Mock(status=400), + content = b'bad request' + ) + + with self.assertRaises(errors.HttpError) as context: + create_job('mock_project', job) + + self.assertEqual(400, context.exception.resp.status) + + def test_execute_job_status_fail(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + mock_kfp_context().__enter__().context_id.return_value = 'ctx1' + job = {} + returned_job = { + 'jobId': 'mock_job', + 'trainingInput': { + 'modelDir': 'test' + }, + 'state': 'FAILED' + } + mock_mlengine_client().get_job.return_value = returned_job + + with self.assertRaises(RuntimeError): + create_job('mock_project', job) + + def test_cancel_succeed(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + mock_kfp_context().__enter__().context_id.return_value = 'ctx1' + job = {} + returned_job = { + 'jobId': 'job_ctx1', + 'state': 'SUCCEEDED' + } + mock_mlengine_client().get_job.return_value = ( + returned_job) + create_job('mock_project', job) + cancel_func = mock_kfp_context.call_args[1]['on_cancel'] + + cancel_func() + + mock_mlengine_client().cancel_job.assert_called_with( + 'mock_project', 'job_ctx1' + ) diff --git a/component_sdk/python/tests/google/ml_engine/test__create_model.py b/component_sdk/python/tests/google/ml_engine/test__create_model.py new file mode 100644 index 00000000000..9322f934b28 --- /dev/null +++ b/component_sdk/python/tests/google/ml_engine/test__create_model.py @@ -0,0 +1,90 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +import unittest + +from googleapiclient import errors +from kfp_component.google.ml_engine import create_model + +CREATE_MODEL_MODULE = 'kfp_component.google.ml_engine._create_model' + +@mock.patch(CREATE_MODEL_MODULE + '.display.display') +@mock.patch(CREATE_MODEL_MODULE + '.gcp_common.dump_file') +@mock.patch(CREATE_MODEL_MODULE + '.KfpExecutionContext') +@mock.patch(CREATE_MODEL_MODULE + '.MLEngineClient') +class TestCreateModel(unittest.TestCase): + + def test_create_model_succeed(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + model = { + 'name': 'mock_model', + 'description': 'the mock model' + } + mock_mlengine_client().create_model.return_value = model + + result = create_model('mock_project', 'mock_model', model) + + self.assertEqual(model, result) + + def test_create_model_conflict_succeed(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + model = { + 'name': 'mock_model', + 'description': 'the mock model' + } + mock_mlengine_client().create_model.side_effect = errors.HttpError( + resp = mock.Mock(status=409), + content = b'conflict' + ) + mock_mlengine_client().get_model.return_value = model + + result = create_model('mock_project', 'mock_model', model) + + self.assertEqual(model, result) + + def test_create_model_conflict_fail(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + model = { + 'name': 'mock_model', + 'description': 'the mock model' + } + mock_mlengine_client().create_model.side_effect = errors.HttpError( + resp = mock.Mock(status=409), + content = b'conflict' + ) + changed_model = { + 'name': 'mock_model', + 'description': 'the changed mock model' + } + mock_mlengine_client().get_model.return_value = changed_model + + with self.assertRaises(errors.HttpError) as context: + create_model('mock_project', 'mock_model', model) + + self.assertEqual(409, context.exception.resp.status) + + def test_create_model_use_context_id_as_name(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + context_id = 'context1' + model = {} + returned_model = { + 'name': 'model_' + context_id + } + mock_mlengine_client().create_model.return_value = returned_model + mock_kfp_context().__enter__().context_id.return_value = context_id + + create_model('mock_project', model=model) + + mock_mlengine_client().create_model.assert_called_with( + project_id = 'mock_project', + model = returned_model + ) \ No newline at end of file diff --git a/component_sdk/python/tests/google/ml_engine/test__create_version.py b/component_sdk/python/tests/google/ml_engine/test__create_version.py new file mode 100644 index 00000000000..053fb69194e --- /dev/null +++ b/component_sdk/python/tests/google/ml_engine/test__create_version.py @@ -0,0 +1,211 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +import unittest + +from googleapiclient import errors +from kfp_component.google.ml_engine import create_version + +CREATE_VERSION_MODULE = 'kfp_component.google.ml_engine._create_version' + +@mock.patch(CREATE_VERSION_MODULE + '.display.display') +@mock.patch(CREATE_VERSION_MODULE + '.gcp_common.dump_file') +@mock.patch(CREATE_VERSION_MODULE + '.KfpExecutionContext') +@mock.patch(CREATE_VERSION_MODULE + '.MLEngineClient') +class TestCreateVersion(unittest.TestCase): + + def test_create_version_succeed(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + version = { + 'description': 'the mock version' + } + mock_mlengine_client().get_version.return_value = None + mock_mlengine_client().create_version.return_value = { + 'name': 'mock_operation_name' + } + mock_mlengine_client().get_operation.return_value = { + 'done': True, + 'response': version + } + + result = create_version('mock_project', 'mock_model', + deployemnt_uri = 'gs://test-location', version_name = 'mock_version', + version = version, + replace_existing = True) + + self.assertEqual(version, result) + + def test_create_version_fail(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + version = { + 'name': 'mock_version', + 'description': 'the mock version', + 'deploymentUri': 'gs://test-location' + } + mock_mlengine_client().get_version.return_value = None + mock_mlengine_client().create_version.return_value = { + 'name': 'mock_operation_name' + } + mock_mlengine_client().get_operation.return_value = { + 'done': True, + 'error': { + 'code': 400, + 'message': 'bad request' + } + } + + with self.assertRaises(RuntimeError) as context: + create_version('mock_project', 'mock_model', + version = version, replace_existing = True, wait_interval = 30) + + self.assertEqual( + 'Failed to complete create version operation mock_operation_name: 400 bad request', + str(context.exception)) + + def test_create_version_dup_version_succeed(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + version = { + 'name': 'mock_version', + 'description': 'the mock version', + 'deploymentUri': 'gs://test-location' + } + pending_version = { + 'state': 'CREATING' + } + pending_version.update(version) + ready_version = { + 'state': 'READY' + } + ready_version.update(version) + mock_mlengine_client().get_version.side_effect = [ + pending_version, ready_version] + + result = create_version('mock_project', 'mock_model', version = version, + replace_existing = True, wait_interval = 0) + + self.assertEqual(ready_version, result) + + def test_create_version_failed_state(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + version = { + 'name': 'mock_version', + 'description': 'the mock version', + 'deploymentUri': 'gs://test-location' + } + pending_version = { + 'state': 'CREATING' + } + pending_version.update(version) + failed_version = { + 'state': 'FAILED', + 'errorMessage': 'something bad happens' + } + failed_version.update(version) + mock_mlengine_client().get_version.side_effect = [ + pending_version, failed_version] + + with self.assertRaises(RuntimeError) as context: + create_version('mock_project', 'mock_model', version = version, + replace_existing = True, wait_interval = 0) + + self.assertEqual( + 'Version is in failed state: something bad happens', + str(context.exception)) + + def test_create_version_conflict_version_replace_succeed(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + version = { + 'name': 'mock_version', + 'description': 'the mock version', + 'deploymentUri': 'gs://test-location' + } + conflicting_version = { + 'name': 'mock_version', + 'description': 'the changed mock version', + 'deploymentUri': 'gs://changed-test-location', + 'state': 'READY' + } + mock_mlengine_client().get_version.return_value = conflicting_version + mock_mlengine_client().delete_version.return_value = { + 'name': 'delete_operation_name' + } + mock_mlengine_client().create_version.return_value = { + 'name': 'create_operation_name' + } + delete_operation = { 'response': {}, 'done': True } + create_operation = { 'response': version, 'done': True } + mock_mlengine_client().get_operation.side_effect = [ + delete_operation, + create_operation + ] + + result = create_version('mock_project', 'mock_model', version = version, + replace_existing = True, wait_interval = 0) + + self.assertEqual(version, result) + + def test_create_version_conflict_version_delete_fail(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + version = { + 'name': 'mock_version', + 'description': 'the mock version', + 'deploymentUri': 'gs://test-location' + } + conflicting_version = { + 'name': 'mock_version', + 'description': 'the changed mock version', + 'deploymentUri': 'gs://changed-test-location', + 'state': 'READY' + } + mock_mlengine_client().get_version.return_value = conflicting_version + mock_mlengine_client().delete_version.return_value = { + 'name': 'delete_operation_name' + } + delete_operation = { + 'done': True, + 'error': { + 'code': 400, + 'message': 'bad request' + } + } + mock_mlengine_client().get_operation.return_value = delete_operation + + with self.assertRaises(RuntimeError) as context: + create_version('mock_project', 'mock_model', version = version, + replace_existing = True, wait_interval = 0) + + self.assertEqual( + 'Failed to complete delete version operation delete_operation_name: 400 bad request', + str(context.exception)) + + def test_create_version_conflict_version_fail(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json, mock_display): + version = { + 'name': 'mock_version', + 'description': 'the mock version', + 'deploymentUri': 'gs://test-location' + } + conflicting_version = { + 'name': 'mock_version', + 'description': 'the changed mock version', + 'deploymentUri': 'gs://changed-test-location', + 'state': 'READY' + } + mock_mlengine_client().get_version.return_value = conflicting_version + + with self.assertRaises(RuntimeError) as context: + create_version('mock_project', 'mock_model', version = version, + replace_existing = False, wait_interval = 0) + + self.assertEqual( + 'Existing version conflicts with the name of the new version.', + str(context.exception)) \ No newline at end of file diff --git a/component_sdk/python/tests/google/ml_engine/test__delete_version.py b/component_sdk/python/tests/google/ml_engine/test__delete_version.py new file mode 100644 index 00000000000..5976ef3e448 --- /dev/null +++ b/component_sdk/python/tests/google/ml_engine/test__delete_version.py @@ -0,0 +1,52 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +import unittest + +from googleapiclient import errors +from kfp_component.google.ml_engine import delete_version + +DELETE_VERSION_MODULE = 'kfp_component.google.ml_engine._delete_version' + +@mock.patch(DELETE_VERSION_MODULE + '.gcp_common.dump_file') +@mock.patch(DELETE_VERSION_MODULE + '.KfpExecutionContext') +@mock.patch(DELETE_VERSION_MODULE + '.MLEngineClient') +class TestDeleteVersion(unittest.TestCase): + + def test_execute_succeed(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json): + mock_mlengine_client().get_version.return_value = { + 'state': 'READY', + } + mock_mlengine_client().delete_version.return_value = { + 'name': 'mock_operation_name' + } + mock_mlengine_client().get_operation.return_value = { + 'done': True + } + + delete_version('mock_project', 'mock_model', 'mock_version', + wait_interval = 30) + + mock_mlengine_client().delete_version.assert_called_once() + + def test_execute_retry_succeed(self, mock_mlengine_client, + mock_kfp_context, mock_dump_json): + pending_version = { + 'state': 'DELETING', + } + mock_mlengine_client().get_version.side_effect = [pending_version, None] + + delete_version('mock_project', 'mock_model', 'mock_version', + wait_interval = 0) + + self.assertEqual(2, mock_mlengine_client().get_version.call_count) \ No newline at end of file diff --git a/component_sdk/python/tests/google/ml_engine/test__train.py b/component_sdk/python/tests/google/ml_engine/test__train.py new file mode 100644 index 00000000000..e6716fdd422 --- /dev/null +++ b/component_sdk/python/tests/google/ml_engine/test__train.py @@ -0,0 +1,44 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import mock +import unittest + +from kfp_component.google.ml_engine import train + +CREATE_JOB_MODULE = 'kfp_component.google.ml_engine._train' + +@mock.patch(CREATE_JOB_MODULE + '.create_job') +class TestCreateTraingingJob(unittest.TestCase): + + def test_train_succeed(self, mock_create_job): + train('proj-1', 'mock.module', ['gs://test/package'], + 'region-1', args=['arg-1', 'arg-2'], job_dir='gs://test/job/dir', + training_input={ + 'runtimeVersion': '1.10', + 'pythonVersion': '2.7' + }, job_id_prefix='job-') + + mock_create_job.assert_called_with('proj-1', { + 'trainingInput': { + 'pythonModule': 'mock.module', + 'packageUris': ['gs://test/package'], + 'region': 'region-1', + 'args': ['arg-1', 'arg-2'], + 'jobDir': 'gs://test/job/dir', + 'runtimeVersion': '1.10', + 'pythonVersion': '2.7' + } + }, 'job-', 30)