diff --git a/sdk/python/kfp/_client.py b/sdk/python/kfp/_client.py
index 53de375adf95..938721ff46c0 100644
--- a/sdk/python/kfp/_client.py
+++ b/sdk/python/kfp/_client.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import string
-import random
 import time
 import logging
 import json
@@ -24,7 +22,7 @@
 import warnings
 import yaml
 import zipfile
-from datetime import datetime
+import datetime
 from typing import Mapping, Callable
 
 import kfp
@@ -35,6 +33,12 @@
 
 from kfp._auth import get_auth_token, get_gcp_access_token
 
+# TTL of the access token associated with the client. This is needed because
+# `gcloud auth print-access-token` generates a token with TTL=1 hour, after
+# which the authentication expires. This TTL is needed for kfp.Client()
+# initialized with host=<inverse proxy endpoint>.
+# Set to 55 mins to provide some safe margin.
+_GCP_ACCESS_TOKEN_TIMEOUT = datetime.timedelta(minutes=55)
 
 
 def _add_generated_apis(target_struct, api_module, api_client):
@@ -108,6 +112,9 @@ def __init__(self, host=None, client_id=None, namespace='kubeflow', other_client
     host = host or os.environ.get(KF_PIPELINES_ENDPOINT_ENV)
     self._uihost = os.environ.get(KF_PIPELINES_UI_ENDPOINT_ENV, host)
     config = self._load_config(host, client_id, namespace, other_client_id, other_client_secret, existing_token)
+    # Save the loaded API client configuration, as a reference if update is
+    # needed.
+    self._existing_config = config
     api_client = kfp_server_api.api_client.ApiClient(config)
     _add_generated_apis(self, kfp_server_api, api_client)
     self._job_api = kfp_server_api.api.job_service_api.JobServiceApi(api_client)
@@ -150,10 +157,13 @@ def _load_config(self, host, client_id, namespace, other_client_id, other_client
     #
     if existing_token:
       token = existing_token
+      self._is_refresh_token = False
     elif client_id:
       token = get_auth_token(client_id, other_client_id, other_client_secret)
+      self._is_refresh_token = True
     elif self._is_inverse_proxy_host(host):
       token = get_gcp_access_token()
+      self._is_refresh_token = False
 
     if token:
       config.api_key['authorization'] = token
@@ -226,6 +236,14 @@ def _load_context_setting_or_default(self):
       self._context_setting = {
         'namespace': '',
       }
+
+  def _refresh_api_client_token(self):
+    """Refreshes the existing token associated with the kfp_api_client."""
+    if getattr(self, '_is_refresh_token'):
+      return
+
+    new_token = get_gcp_access_token()
+    self._existing_config.api_key['authorization'] = new_token
 
   def set_user_namespace(self, namespace):
     """Set user namespace into local context setting file.
@@ -531,7 +549,7 @@ def create_run_from_pipeline_func(self, pipeline_func: Callable, arguments: Mapp
     '''
     #TODO: Check arguments against the pipeline function
     pipeline_name = pipeline_func.__name__
-    run_name = run_name or pipeline_name + ' ' + datetime.now().strftime('%Y-%m-%d %H-%M-%S')
+    run_name = run_name or pipeline_name + ' ' + datetime.datetime.now().strftime('%Y-%m-%d %H-%M-%S')
     with tempfile.TemporaryDirectory() as tmpdir:
       pipeline_package_path = os.path.join(tmpdir, 'pipeline.yaml')
       compiler.Compiler().compile(pipeline_func, pipeline_package_path, pipeline_conf=pipeline_conf)
@@ -558,7 +576,7 @@ def __init__(self, client, run_info):
     self.run_id = run_info.id
 
   def wait_for_run_completion(self, timeout=None):
-    timeout = timeout or datetime.datetime.max - datetime.datetime.min
+    timeout = timeout or datetime.timedelta.max
     return self._client.wait_for_run_completion(self.run_id, timeout)
 
   def __repr__(self):
@@ -572,7 +590,9 @@ def __repr__(self):
       import warnings
       warnings.warn('Changing experiment name from "{}" to "{}".'.format(experiment_name, overridden_experiment_name))
     experiment_name = overridden_experiment_name or 'Default'
-    run_name = run_name or pipeline_name + ' ' + datetime.now().strftime('%Y-%m-%d %H-%M-%S')
+    run_name = run_name or (pipeline_name + ' ' +
+                            datetime.datetime.now().strftime(
+                                '%Y-%m-%d %H-%M-%S'))
     experiment = self.create_experiment(name=experiment_name, namespace=namespace)
     run_info = self.run_pipeline(experiment.id, run_name, pipeline_file, arguments)
     return RunPipelineResult(self, run_info)
@@ -639,19 +659,30 @@ def get_run(self, run_id):
     return self._run_api.get_run(run_id=run_id)
 
   def wait_for_run_completion(self, run_id, timeout):
-    """Wait for a run to complete.
+    """Waits for a run to complete.
     Args:
       run_id: run id, returned from run_pipeline.
       timeout: timeout in seconds.
     Returns:
-      A run detail object: Most important fields are run and pipeline_runtime
+      A run detail object: Most important fields are run and pipeline_runtime.
+    Raises:
+      TimeoutError: if the pipeline run failed to finish before the specified
+        timeout.
     """
     status = 'Running:'
-    start_time = datetime.now()
-    while status is None or status.lower() not in ['succeeded', 'failed', 'skipped', 'error']:
+    start_time = datetime.datetime.now()
+    last_token_refresh_time = datetime.datetime.now()
+    while (status is None or
+           status.lower() not in ['succeeded', 'failed', 'skipped', 'error']):
+      # Refreshes the access token before it hits the TTL.
+      if (datetime.datetime.now() - last_token_refresh_time
+          > _GCP_ACCESS_TOKEN_TIMEOUT):
+        self._refresh_api_client_token()
+        last_token_refresh_time = datetime.datetime.now()
+
       get_run_response = self._run_api.get_run(run_id=run_id)
       status = get_run_response.run.status
-      elapsed_time = (datetime.now() - start_time).seconds
+      elapsed_time = (datetime.datetime.now() - start_time).seconds
       logging.info('Waiting for the job to complete...')
      if elapsed_time > timeout:
        raise TimeoutError('Run timeout')
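A minimal usage sketch of the scenario this change targets (not part of the patch): a long-running pipeline submitted through an inverse-proxy host, where the polling loop in wait_for_run_completion previously outlived the one-hour `gcloud` access token. The host URL and the sleep pipeline below are placeholders.

    import kfp
    from kfp import dsl


    @dsl.pipeline(name='sleep-demo', description='Placeholder long-running pipeline.')
    def sleep_pipeline():
      # Hypothetical step that sleeps for two hours, long enough to outlive
      # a one-hour access token.
      dsl.ContainerOp(
          name='sleep',
          image='alpine:3.10',
          command=['sh', '-c', 'sleep 7200'])


    # Hypothetical inverse-proxy endpoint; for hosts like this the client
    # obtains a GCP access token and records _is_refresh_token = False, so the
    # token can later be regenerated with get_gcp_access_token().
    client = kfp.Client(host='https://xxxx-dot-us-central1.pipelines.googleusercontent.com')

    run = client.create_run_from_pipeline_func(sleep_pipeline, arguments={})

    # With this patch, the client refreshes its token every 55 minutes
    # (_GCP_ACCESS_TOKEN_TIMEOUT) while polling, so a multi-hour wait no longer
    # hits the token expiry mid-run.
    run.wait_for_run_completion(timeout=3 * 60 * 60)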