diff --git a/metaflow/plugins/aws/eks/kubernetes.py b/metaflow/plugins/aws/eks/kubernetes.py
index f1acfb30fd4..37aa00baddb 100644
--- a/metaflow/plugins/aws/eks/kubernetes.py
+++ b/metaflow/plugins/aws/eks/kubernetes.py
@@ -260,7 +260,7 @@ def wait(self, echo=None):
         stderr_location = ds.get_log_location(TASK_LOG_SOURCE, "stderr")
 
         def wait_for_launch(job):
-            status = job.status
+            status = job.get_status()
             echo(
                 "Task is starting (Status %s)..." % status,
                 "stderr",
@@ -268,7 +268,7 @@ def wait_for_launch(job):
             )
             t = time.time()
             while True:
-                new_status = job.status
+                new_status = job.get_status()
                 if status != new_status or (time.time() - t) > 30:
                     status = new_status
                     echo(
@@ -277,7 +277,7 @@ def wait_for_launch(job):
                         job_id=job.id,
                     )
                     t = time.time()
-                if job.is_running or job.is_done:
+                if job.check_pod_is_running() or job.check_is_done():
                     break
                 time.sleep(1)
 
@@ -317,7 +317,7 @@ def _print_available(tail, stream, should_persist=False):
             now = time.time()
             log_update_delay = update_delay(now - start_time)
             next_log_update = now + log_update_delay
-            is_running = self._job.is_running
+            is_running = self._job.check_pod_is_running()
 
             # This sleep should never delay log updates. On the other hand,
             # we should exit this loop when the task has finished without
@@ -337,8 +337,32 @@ def _print_available(tail, stream, should_persist=False):
         _print_available(stdout_tail, "stdout")
         _print_available(stderr_tail, "stderr")
 
-        if self._job.has_failed:
-            exit_code, reason = self._job.reason
+        try:
+            # Now the pod is no longer running, but we need to wait for the job
+            # status to update. *Usually* that happens pretty much immediately,
+            # but since it is technically done by the job controller
+            # asynchronously, there's room for a race condition where the pod
+            # is done but the job is still "active". It is more likely to happen
+            # when the control plane is overloaded, e.g. on minikube/kind.
+            self._job.wait_done(timeout_seconds=20)
+        except TimeoutError:
+            # We shouldn't really get here unless the K8S control plane is
+            # really unhealthy.
+            echo(
+                "Pod is not running but the job is not done or failed, last job state: %s" % self._job._job,
+                "stderr",
+                job_id=self._job.id,
+            )
+            echo(
+                "last pod state: %s" % self._job._pod,
+                "stderr",
+                job_id=self._job.id,
+            )
+            # Kill the job if it is still running by throwing an exception.
+            raise KubernetesKilledException("Task failed!")
+
+        if self._job.check_has_failed():
+            exit_code, reason = self._job.get_done_reason()
             msg = next(
                 msg
                 for msg in [
@@ -358,10 +382,8 @@ def _print_available(tail, stream, should_persist=False):
                 "%s. This could be a transient error. "
                 "Use @retry to retry." % msg
             )
-        elif not self._job.is_done:
-            # Kill the job if it is still running by throwing an exception.
-            raise KubernetesKilledException("Task failed!")
-        exit_code, _ = self._job.reason
+
+        exit_code, _ = self._job.get_done_reason()
         echo(
             "Task finished with exit code %s." % exit_code,
             "stderr",
diff --git a/metaflow/plugins/aws/eks/kubernetes_client.py b/metaflow/plugins/aws/eks/kubernetes_client.py
index 2e7cd0a3e96..0f1c086b781 100644
--- a/metaflow/plugins/aws/eks/kubernetes_client.py
+++ b/metaflow/plugins/aws/eks/kubernetes_client.py
@@ -1,4 +1,5 @@
 import os
+import time
 
 try:
     unicode
@@ -8,6 +9,12 @@
 from metaflow.exception import MetaflowException
 
+# The pod object may not be created immediately after submitting the job,
+# we need to have some tolerance here when fetching the pod.
+POD_FETCH_BACKOFF_SECONDS = 1.0
+POD_FETCH_RETRIES = 10
+
+CLIENT_REFRESH_INTERVAL_SECONDS = 300
 
 class KubernetesJobException(MetaflowException):
     headline = "Kubernetes job error"
 
@@ -28,6 +35,10 @@ def __init__(self):
                 "Could not import module 'kubernetes'. Install kubernetes "
                 "Python package (https://pypi.org/project/kubernetes/) first."
             )
+        self._refresh_client()
+
+    def _refresh_client(self):
+        from kubernetes import client, config
         if os.getenv("KUBERNETES_SERVICE_HOST"):
             # We’re inside a pod, authenticate via ServiceAccount assigned to us
             config.load_incluster_config()
@@ -40,14 +51,21 @@ def __init__(self):
             # good enough for the initial rollout.
             config.load_kube_config()
         self._client = client
+        self._client_refresh_timestamp = time.time()
 
     def job(self, **kwargs):
-        return KubernetesJob(self._client, **kwargs)
+        return KubernetesJob(self, **kwargs)
+
+    def get(self):
+        if time.time() - self._client_refresh_timestamp > CLIENT_REFRESH_INTERVAL_SECONDS:
+            self._refresh_client()
+
+        return self._client
 
 
 class KubernetesJob(object):
-    def __init__(self, client, **kwargs):
-        self._client = client
+    def __init__(self, client_wrapper, **kwargs):
+        self._client_wrapper = client_wrapper
         self._kwargs = kwargs
 
         # Kubernetes namespace defaults to `default`
@@ -104,10 +122,11 @@ def create(self):
         #
         # Note: This implementation ensures that there is only one unique Pod
         # (unique UID) per Metaflow task attempt.
-        self._job = self._client.V1Job(
+        client = self._client_wrapper.get()
+        self._job = client.V1Job(
             api_version="batch/v1",
             kind="Job",
-            metadata=self._client.V1ObjectMeta(
+            metadata=client.V1ObjectMeta(
                 # Annotations are for humans
                 annotations=self._kwargs.get("annotations", {}),
                 # While labels are for Kubernetes
@@ -115,7 +134,7 @@
                 name=self._kwargs["name"],  # Unique within the namespace
                 namespace=self._kwargs["namespace"],  # Defaults to `default`
             ),
-            spec=self._client.V1JobSpec(
+            spec=client.V1JobSpec(
                 # Retries are handled by Metaflow when it is responsible for
                 # executing the flow. The responsibility is moved to Kubernetes
                 # when AWS Step Functions / Argo are responsible for the
@@ -128,14 +147,14 @@
                 * 60
                 * 60  # Remove job after a week. TODO (savin): Make this
                 * 24,  # configurable
-                template=self._client.V1PodTemplateSpec(
-                    metadata=self._client.V1ObjectMeta(
+                template=client.V1PodTemplateSpec(
+                    metadata=client.V1ObjectMeta(
                         annotations=self._kwargs.get("annotations", {}),
                         labels=self._kwargs.get("labels", {}),
                         name=self._kwargs["name"],
                         namespace=self._kwargs["namespace"],
                     ),
-                    spec=self._client.V1PodSpec(
+                    spec=client.V1PodSpec(
                         # Timeout is set on the pod and not the job (important!)
                         active_deadline_seconds=self._kwargs[
                             "timeout_in_seconds"
@@ -148,10 +167,10 @@
                         # roll out.
                         # affinity=?,
                         containers=[
-                            self._client.V1Container(
+                            client.V1Container(
                                 command=self._kwargs["command"],
                                 env=[
-                                    self._client.V1EnvVar(name=k, value=str(v))
+                                    client.V1EnvVar(name=k, value=str(v))
                                     for k, v in self._kwargs.get(
                                         "environment_variables", {}
                                     ).items()
@@ -163,10 +182,10 @@
                                 # TODO: Figure out a way to make job
                                 # metadata visible within the container
                                 + [
-                                    self._client.V1EnvVar(
+                                    client.V1EnvVar(
                                         name=k,
-                                        value_from=self._client.V1EnvVarSource(
-                                            field_ref=self._client.V1ObjectFieldSelector(
+                                        value_from=client.V1EnvVarSource(
+                                            field_ref=client.V1ObjectFieldSelector(
                                                 field_path=str(v)
                                             )
                                         ),
@@ -178,8 +197,8 @@
                                     }.items()
                                 ],
                                 env_from=[
-                                    self._client.V1EnvFromSource(
-                                        secret_ref=self._client.V1SecretEnvSource(
+                                    client.V1EnvFromSource(
+                                        secret_ref=client.V1SecretEnvSource(
                                             name=str(k)
                                         )
                                     )
@@ -187,7 +206,7 @@
                                 ],
                                 image=self._kwargs["image"],
                                 name=self._kwargs["name"],
-                                resources=self._client.V1ResourceRequirements(
+                                resources=client.V1ResourceRequirements(
                                     requests={
                                         "cpu": str(self._kwargs["cpu"]),
                                         "memory": "%sM"
@@ -239,25 +258,26 @@ def create(self):
         return self
 
     def execute(self):
+        client = self._client_wrapper.get()
        try:
             # TODO (savin): Make job submission back-pressure aware. Currently
             #               there doesn't seem to be a kubernetes-native way to
             #               achieve the guarantees that we are seeking.
             #               Hopefully, we will be able to get creative soon.
             response = (
-                self._client.BatchV1Api()
+                client.BatchV1Api()
                 .create_namespaced_job(
                     body=self._job, namespace=self._kwargs["namespace"]
                 )
                 .to_dict()
             )
             return RunningJob(
-                client=self._client,
+                client_wrapper=self._client_wrapper,
                 name=response["metadata"]["name"],
                 uid=response["metadata"]["uid"],
                 namespace=response["metadata"]["namespace"],
             )
-        except self._client.rest.ApiException as e:
+        except client.rest.ApiException as e:
             raise KubernetesJobException(
                 "Unable to launch Kubernetes job.\n %s" % str(e)
             )
@@ -348,8 +368,8 @@ class RunningJob(object):
     JOB_ACTIVE = "job:active"
     JOB_FAILED = ""
 
-    def __init__(self, client, name, uid, namespace):
-        self._client = client
+    def __init__(self, client_wrapper, name, uid, namespace):
+        self._client_wrapper = client_wrapper
         self._name = name
         self._id = uid
         self._namespace = namespace
@@ -367,31 +387,35 @@ def __repr__(self):
         )
 
     def _fetch_job(self):
+        client = self._client_wrapper.get()
         try:
             return (
-                self._client.BatchV1Api()
+                client.BatchV1Api()
                 .read_namespaced_job(name=self._name, namespace=self._namespace)
                 .to_dict()
             )
-        except self._client.rest.ApiException as e:
+        except client.rest.ApiException as e:
             # TODO: Handle failures as well as the fact that a different
             #       process can delete the job.
             raise e
 
     def _fetch_pod(self):
+        client = self._client_wrapper.get()
         try:
-            # TODO (savin): pods may not appear immediately or they may
-            #               disappear
-            return (
-                self._client.CoreV1Api()
-                .list_namespaced_pod(
-                    namespace=self._namespace,
-                    label_selector="job-name={}".format(self._name),
-                )
-                .to_dict()["items"]
-                or [None]
-            )[0]
-        except self._client.rest.ApiException as e:
+            # Pod objects may not get created immediately after job submission
+            for _ in range(POD_FETCH_RETRIES):
+                pods = client.CoreV1Api().list_namespaced_pod(
+                    namespace=self._namespace,
+                    label_selector="job-name={}".format(self._name),
+                ).to_dict()["items"]
+
+                if pods:
+                    return pods[0]
+                else:
+                    time.sleep(POD_FETCH_BACKOFF_SECONDS)
+            else:
+                raise Exception('Could not fetch pod status in %s seconds' % (POD_FETCH_RETRIES * POD_FETCH_BACKOFF_SECONDS))
+        except client.rest.ApiException as e:
             # TODO: Handle failures
             raise e
 
@@ -418,12 +442,12 @@ def kill(self):
         #    terminate the pod without deleting the object.
         # 3. If the pod object hasn't shown up yet, we set the parallelism to 0
         #    to preempt it.
-        if not self.is_done:
-            if self.is_running:
+        client = self._client_wrapper.get()
+        if not self.check_is_done():
+            if self.check_pod_is_running():
                 # Case 1.
                 from kubernetes.stream import stream
-
-                api_instance = self._client.CoreV1Api
+                api_instance = client.CoreV1Api
                 try:
                     # TODO (savin): stream opens a web-socket connection. It may
                     #               not be desirable to open multiple web-socket
@@ -454,7 +478,7 @@
                 try:
                     # TODO (savin): Also patch job annotation to reflect this
                     #               action.
-                    self._client.BatchV1Api().patch_namespaced_job(
+                    client.BatchV1Api().patch_namespaced_job(
                         name=self._name,
                         namespace=self._namespace,
                         field_manager="metaflow",
@@ -472,8 +496,7 @@ def id(self):
         # TODO (savin): Should we use pod id instead?
         return self._id
 
-    @property
-    def is_done(self):
+    def check_is_done(self):
         def _done():
             # Either the job succeeds or fails naturally or we may have
             # forced the pod termination causing the job to still be in an
@@ -494,10 +517,20 @@ def _done():
             self._job = self._fetch_job()
         return _done()
 
-    @property
-    def status(self):
-        if not self.is_done:
-            # If not done, check for newer status
+
+    def wait_done(self, timeout_seconds):
+        deadline = time.time() + timeout_seconds
+        while time.time() < deadline:
+            if self.check_is_done():
+                return True
+            else:
+                time.sleep(1)
+        raise TimeoutError("Timed out while waiting for Job to become done")
+
+
+    def get_status(self):
+        if not self.check_is_done():
+            # If not done, check for newer pod status
             self._pod = self._fetch_pod()
         # Success!
         if bool(self._job["status"].get("succeeded")):
@@ -534,38 +567,33 @@ def status(self):
                     return msg
         return "Job:Unknown"
 
-    @property
-    def has_succeeded(self):
+    def check_has_succeeded(self):
         # Job is in a terminal state and the status is marked as succeeded
-        return self.is_done and bool(self._job["status"].get("succeeded"))
+        return self.check_is_done() and bool(self._job["status"].get("succeeded"))
 
-    @property
-    def has_failed(self):
+    def check_has_failed(self):
         # Job is in a terminal state and either the status is marked as failed
         # or the Job is not allowed to launch any more pods
-        return self.is_done and (
+        return self.check_is_done() and (
             bool(self._job["status"].get("failed"))
             or (self._job["spec"]["parallelism"] == 0)
         )
 
-    @property
-    def is_running(self):
-        # The container is running. This happens when the Pod's phase is running
-        if not self.is_done:
+    def check_pod_is_running(self):
+        # Returns true if the container is running. There are two situations
+        # where check_pod_is_running may return False, either:
+        # - the job is done
+        # - the container hasn't started *yet*
+        if not self.check_is_done():
             # If not done, check if pod has been assigned and is in Running
             # phase
             self._pod = self._fetch_pod()
             return self._pod.get("status", {}).get("phase") == "Running"
         return False
 
-    @property
-    def is_waiting(self):
-        return not self.is_done and not self.is_running
-
-    @property
-    def reason(self):
-        if self.is_done:
-            if self.has_succeeded:
+    def get_done_reason(self):
+        if self.check_is_done():
+            if self.check_has_succeeded():
                 return 0, None
             # Best effort since Pod object can disappear on us at anytime
             else: