diff --git a/metaflow/plugins/aws/eks/kubernetes.py b/metaflow/plugins/aws/eks/kubernetes.py index f1acfb30fd4..8c2a945b0c0 100644 --- a/metaflow/plugins/aws/eks/kubernetes.py +++ b/metaflow/plugins/aws/eks/kubernetes.py @@ -358,9 +358,7 @@ def _print_available(tail, stream, should_persist=False): "%s. This could be a transient error. " "Use @retry to retry." % msg ) - elif not self._job.is_done: - # Kill the job if it is still running by throwing an exception. - raise KubernetesKilledException("Task failed!") + exit_code, _ = self._job.reason echo( "Task finished with exit code %s." % exit_code, diff --git a/metaflow/plugins/aws/eks/kubernetes_client.py b/metaflow/plugins/aws/eks/kubernetes_client.py index 56521720cce..e56c2f68de0 100644 --- a/metaflow/plugins/aws/eks/kubernetes_client.py +++ b/metaflow/plugins/aws/eks/kubernetes_client.py @@ -1,4 +1,5 @@ import os +import time try: unicode @@ -8,6 +9,8 @@ from metaflow.exception import MetaflowException +CLIENT_REFRESH_INTERVAL_SECONDS = 300 + class KubernetesJobException(MetaflowException): headline = "Kubernetes job error" @@ -28,6 +31,11 @@ def __init__(self): "Could not import module 'kubernetes'. Install kubernetes " "Python package (https://pypi.org/project/kubernetes/) first." ) + self._refresh_client() + + def _refresh_client(self): + from kubernetes import client, config + if os.getenv("KUBERNETES_SERVICE_HOST"): # We’re inside a pod, authenticate via ServiceAccount assigned to us config.load_incluster_config() @@ -40,14 +48,24 @@ def __init__(self): # good enough for the initial rollout. config.load_kube_config() self._client = client + self._client_refresh_timestamp = time.time() def job(self, **kwargs): - return KubernetesJob(self._client, **kwargs) + return KubernetesJob(self, **kwargs) + + def get(self): + if ( + time.time() - self._client_refresh_timestamp + > CLIENT_REFRESH_INTERVAL_SECONDS + ): + self._refresh_client() + + return self._client class KubernetesJob(object): - def __init__(self, client, **kwargs): - self._client = client + def __init__(self, client_wrapper, **kwargs): + self._client_wrapper = client_wrapper self._kwargs = kwargs # Kubernetes namespace defaults to `default` @@ -104,10 +122,11 @@ def create(self): # # Note: This implementation ensures that there is only one unique Pod # (unique UID) per Metaflow task attempt. - self._job = self._client.V1Job( + client = self._client_wrapper.get() + self._job = client.V1Job( api_version="batch/v1", kind="Job", - metadata=self._client.V1ObjectMeta( + metadata=client.V1ObjectMeta( # Annotations are for humans annotations=self._kwargs.get("annotations", {}), # While labels are for Kubernetes @@ -115,7 +134,7 @@ def create(self): name=self._kwargs["name"], # Unique within the namespace namespace=self._kwargs["namespace"], # Defaults to `default` ), - spec=self._client.V1JobSpec( + spec=client.V1JobSpec( # Retries are handled by Metaflow when it is responsible for # executing the flow. The responsibility is moved to Kubernetes # when AWS Step Functions / Argo are responsible for the @@ -128,18 +147,16 @@ def create(self): * 60 * 60 # Remove job after a week. TODO (savin): Make this * 24, # configurable - template=self._client.V1PodTemplateSpec( - metadata=self._client.V1ObjectMeta( + template=client.V1PodTemplateSpec( + metadata=client.V1ObjectMeta( annotations=self._kwargs.get("annotations", {}), labels=self._kwargs.get("labels", {}), name=self._kwargs["name"], namespace=self._kwargs["namespace"], ), - spec=self._client.V1PodSpec( + spec=client.V1PodSpec( # Timeout is set on the pod and not the job (important!) - active_deadline_seconds=self._kwargs[ - "timeout_in_seconds" - ], + active_deadline_seconds=self._kwargs["timeout_in_seconds"], # TODO (savin): Enable affinities for GPU scheduling. # This requires some thought around the # UX since specifying affinities can get @@ -148,10 +165,10 @@ def create(self): # roll out. # affinity=?, containers=[ - self._client.V1Container( + client.V1Container( command=self._kwargs["command"], env=[ - self._client.V1EnvVar(name=k, value=str(v)) + client.V1EnvVar(name=k, value=str(v)) for k, v in self._kwargs.get( "environment_variables", {} ).items() @@ -163,10 +180,10 @@ def create(self): # TODO: Figure out a way to make job # metadata visible within the container + [ - self._client.V1EnvVar( + client.V1EnvVar( name=k, - value_from=self._client.V1EnvVarSource( - field_ref=self._client.V1ObjectFieldSelector( + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( field_path=str(v) ) ), @@ -178,20 +195,17 @@ def create(self): }.items() ], env_from=[ - self._client.V1EnvFromSource( - secret_ref=self._client.V1SecretEnvSource( - name=str(k) - ) + client.V1EnvFromSource( + secret_ref=client.V1SecretEnvSource(name=str(k)) ) for k in self._kwargs.get("secrets", []) ], image=self._kwargs["image"], name=self._kwargs["name"], - resources=self._client.V1ResourceRequirements( + resources=client.V1ResourceRequirements( requests={ "cpu": str(self._kwargs["cpu"]), - "memory": "%sM" - % str(self._kwargs["memory"]), + "memory": "%sM" % str(self._kwargs["memory"]), "ephemeral-storage": "%sM" % str(self._kwargs["disk"]), } @@ -239,25 +253,26 @@ def create(self): return self def execute(self): + client = self._client_wrapper.get() try: # TODO (savin): Make job submission back-pressure aware. Currently # there doesn't seem to be a kubernetes-native way to # achieve the guarantees that we are seeking. # Hopefully, we will be able to get creative soon. response = ( - self._client.BatchV1Api() + client.BatchV1Api() .create_namespaced_job( body=self._job, namespace=self._kwargs["namespace"] ) .to_dict() ) return RunningJob( - client=self._client, + client_wrapper=self._client_wrapper, name=response["metadata"]["name"], uid=response["metadata"]["uid"], namespace=response["metadata"]["namespace"], ) - except self._client.rest.ApiException as e: + except client.rest.ApiException as e: raise KubernetesJobException( "Unable to launch Kubernetes job.\n %s" % str(e) ) @@ -293,9 +308,7 @@ def environment_variable(self, name, value): return self def label(self, name, value): - self._kwargs["labels"] = dict( - self._kwargs.get("labels", {}), **{name: value} - ) + self._kwargs["labels"] = dict(self._kwargs.get("labels", {}), **{name: value}) return self def annotation(self, name, value): @@ -309,6 +322,22 @@ class RunningJob(object): # State Machine implementation for the lifecycle behavior documented in # https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/ + # + # This object encapsulates *both* V1Job and V1Pod. It simplifies the status + # to "running" and "done" (failed/succeeded) state. Note that V1Job and V1Pod + # status fields are not guaranteed to be always in sync due to the way job + # controller works. + # + # For example, for a successful job, RunningJob states and their corresponding + # K8S object states look like this: + # + # | V1JobStatus.active | V1JobStatus.succeeded | V1PodStatus.phase | RunningJob.is_running | RunningJob.is_done | + # |--------------------|-----------------------|-------------------|-----------------------|--------------------| + # | 0 | 0 | N/A | False | False | + # | 0 | 0 | Pending | False | False | + # | 1 | 0 | Running | True | False | + # | 1 | 0 | Succeeded | True | True | + # | 0 | 1 | Succeeded | False | True | # To ascertain the status of V1Job, we peer into the lifecycle status of # the pod it is responsible for executing. Unfortunately, the `phase` @@ -348,8 +377,8 @@ class RunningJob(object): JOB_ACTIVE = "job:active" JOB_FAILED = "" - def __init__(self, client, name, uid, namespace): - self._client = client + def __init__(self, client_wrapper, name, uid, namespace): + self._client_wrapper = client_wrapper self._name = name self._id = uid self._namespace = namespace @@ -367,33 +396,35 @@ def __repr__(self): ) def _fetch_job(self): + client = self._client_wrapper.get() try: return ( - self._client.BatchV1Api() + client.BatchV1Api() .read_namespaced_job(name=self._name, namespace=self._namespace) .to_dict() ) - except self._client.rest.ApiException as e: + except client.rest.ApiException as e: # TODO: Handle failures as well as the fact that a different # process can delete the job. raise e def _fetch_pod(self): - try: - # TODO (savin): pods may not appear immediately or they may - # disappear - return ( - self._client.CoreV1Api() - .list_namespaced_pod( - namespace=self._namespace, - label_selector="job-name={}".format(self._name), - ) - .to_dict()["items"] - or [None] - )[0] - except self._client.rest.ApiException as e: - # TODO: Handle failures - raise e + """Fetch pod metadata. May return None if pod does not exist.""" + client = self._client_wrapper.get() + + pods = ( + client.CoreV1Api() + .list_namespaced_pod( + namespace=self._namespace, + label_selector="job-name={}".format(self._name), + ) + .to_dict()["items"] + ) + + if pods: + return pods[0] + else: + return None def kill(self): # Terminating a Kubernetes job is a bit tricky. Issuing a @@ -418,12 +449,18 @@ def kill(self): # terminate the pod without deleting the object. # 3. If the pod object hasn't shown up yet, we set the parallelism to 0 # to preempt it. - if not self.is_done: - if self.is_running: + client = self._client_wrapper.get() + if not self._check_is_done(): + if self._check_is_running(): + + # Unless there is a bug in the code, self._pod cannot be None + # if we're in "running" state. + assert self._pod is not None + # Case 1. from kubernetes.stream import stream - api_instance = self._client.CoreV1Api + api_instance = client.CoreV1Api try: # TODO (savin): stream opens a web-socket connection. It may # not be desirable to open multiple web-socket @@ -454,7 +491,7 @@ def kill(self): try: # TODO (savin): Also patch job annotation to reflect this # action. - self._client.BatchV1Api().patch_namespaced_job( + client.BatchV1Api().patch_namespaced_job( name=self._name, namespace=self._namespace, field_manager="metaflow", @@ -472,15 +509,11 @@ def id(self): # TODO (savin): Should we use pod id instead? return self._id - @property - def is_done(self): - def _done(): + def _check_is_done(self): + def _job_done(): # Either the job succeeds or fails naturally or we may have # forced the pod termination causing the job to still be in an # active state but for all intents and purposes dead to us. - # - # This method relies exclusively on the state of V1Job object, - # since it's guaranteed to exist during the lifetime of the job. # TODO (savin): check for self._job return ( @@ -489,15 +522,23 @@ def _done(): or (self._job["spec"]["parallelism"] == 0) ) - if not _done(): + if not _job_done(): # If not done, check for newer status self._job = self._fetch_job() - return _done() + if _job_done(): + return True + else: + # It is possible for the job metadata to not be updated yet, but the + # Pod has already succeeded or failed. + self._pod = self._fetch_pod() + if self._pod and (self._pod["status"]["phase"] in ("Succeeded", "Failed")): + return True + else: + return False - @property - def status(self): - if not self.is_done: - # If not done, check for newer status + def _get_status(self): + if not self._check_is_done(): + # If not done, check for newer pod status self._pod = self._fetch_pod() # Success! if bool(self._job["status"].get("succeeded")): @@ -534,38 +575,48 @@ def status(self): return msg return "Job:Unknown" - @property - def has_succeeded(self): + def _check_has_succeeded(self): # Job is in a terminal state and the status is marked as succeeded - return self.is_done and bool(self._job["status"].get("succeeded")) + if self._check_is_done(): + if bool(self._job["status"].get("succeeded")) or ( + self._pod and self._pod["status"]["phase"] == "Succeeded" + ): + return True + else: + return False + else: + return False - @property - def has_failed(self): + def _check_has_failed(self): # Job is in a terminal state and either the status is marked as failed # or the Job is not allowed to launch any more pods - return self.is_done and ( - bool(self._job["status"].get("failed")) - or (self._job["spec"]["parallelism"] == 0) - ) - @property - def is_running(self): - # The container is running. This happens when the Pod's phase is running - if not self.is_done: + if self._check_is_done(): + if ( + bool(self._job["status"].get("failed")) + or (self._job["spec"]["parallelism"] == 0) + or (self._pod and self._pod["status"]["phase"] == "Failed") + ): + return True + else: + return False + else: + return False + + def _check_is_running(self): + # Returns true if the container is running. + if not self._check_is_done(): # If not done, check if pod has been assigned and is in Running # phase - self._pod = self._fetch_pod() - return self._pod.get("status", {}).get("phase") == "Running" + if self._pod is None: + return False + pod_phase = self._pod.get("status", {}).get("phase") + return pod_phase == "Running" return False - @property - def is_waiting(self): - return not self.is_done and not self.is_running - - @property - def reason(self): - if self.is_done: - if self.has_succeeded: + def _get_done_reason(self): + if self._check_is_done(): + if self._check_has_succeeded(): return 0, None # Best effort since Pod object can disappear on us at anytime else: @@ -585,15 +636,14 @@ def _done(): # We're done, but no container_statuses is set # This can happen when the pod is evicted return None, ": ".join( - filter( - None, - [pod_status.get("reason"), pod_status.get("message")], - ) + filter( + None, + [pod_status.get("reason"), pod_status.get("message")], ) + ) for k, v in ( - pod_status - .get("container_statuses", [{}])[0] + pod_status.get("container_statuses", [{}])[0] .get("state", {}) .items() ): @@ -606,3 +656,23 @@ def _done(): ) return None, None + + @property + def is_done(self): + return self._check_is_done() + + @property + def has_failed(self): + return self._check_has_failed() + + @property + def is_running(self): + return self._check_is_running() + + @property + def reason(self): + return self._get_done_reason() + + @property + def status(self): + return self._get_status()