diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py index 5c62fc43d01..00352a69b88 100644 --- a/metaflow/plugins/__init__.py +++ b/metaflow/plugins/__init__.py @@ -88,11 +88,13 @@ def get_plugin_cli(): # Add new CLI commands in this list from . import package_cli from .aws.batch import batch_cli + from .aws.eks import kubernetes_cli from .aws.step_functions import step_functions_cli return _ext_plugins.get_plugin_cli() + [ package_cli.cli, batch_cli.cli, + kubernetes_cli.cli, step_functions_cli.cli] @@ -113,6 +115,7 @@ def _merge_lists(base, overrides, attr): from .retry_decorator import RetryDecorator from .resources_decorator import ResourcesDecorator from .aws.batch.batch_decorator import BatchDecorator +from .aws.eks.kubernetes_decorator import KubernetesDecorator from .aws.step_functions.step_functions_decorator \ import StepFunctionsInternalDecorator from .test_unbounded_foreach_decorator\ @@ -125,6 +128,7 @@ def _merge_lists(base, overrides, attr): ResourcesDecorator, RetryDecorator, BatchDecorator, + KubernetesDecorator, StepFunctionsInternalDecorator, CondaStepDecorator, InternalTestUnboundedForeachDecorator], diff --git a/metaflow/plugins/aws/batch/batch.py b/metaflow/plugins/aws/batch/batch.py index 4784857a61f..07eac64e259 100644 --- a/metaflow/plugins/aws/batch/batch.py +++ b/metaflow/plugins/aws/batch/batch.py @@ -1,40 +1,46 @@ -import os -import time +import atexit import json +import os import select -import atexit import shlex import time -import warnings -from metaflow.exception import MetaflowException, MetaflowInternalError -from metaflow.metaflow_config import BATCH_METADATA_SERVICE_URL, DATATOOLS_S3ROOT, \ - DATASTORE_LOCAL_DIR, DATASTORE_SYSROOT_S3, DEFAULT_METADATA, \ - BATCH_METADATA_SERVICE_HEADERS, BATCH_EMIT_TAGS from metaflow import util - -from .batch_client import BatchClient - from metaflow.datatools.s3tail import S3Tail +from metaflow.exception import MetaflowException, MetaflowInternalError +from metaflow.metaflow_config import ( + BATCH_METADATA_SERVICE_URL, + DATATOOLS_S3ROOT, + DATASTORE_LOCAL_DIR, + DATASTORE_SYSROOT_S3, + DEFAULT_METADATA, + BATCH_METADATA_SERVICE_HEADERS, + BATCH_EMIT_TAGS +) from metaflow.mflog.mflog import refine, set_should_persist -from metaflow.mflog import export_mflog_env_vars,\ - bash_capture_logs,\ - update_delay,\ - BASH_SAVE_LOGS +from metaflow.mflog import ( + export_mflog_env_vars, + bash_capture_logs, + update_delay, + BASH_SAVE_LOGS, +) + +from .batch_client import BatchClient # Redirect structured logs to /logs/ -LOGS_DIR = '/logs' -STDOUT_FILE = 'mflog_stdout' -STDERR_FILE = 'mflog_stderr' +LOGS_DIR = "/logs" +STDOUT_FILE = "mflog_stdout" +STDERR_FILE = "mflog_stderr" STDOUT_PATH = os.path.join(LOGS_DIR, STDOUT_FILE) STDERR_PATH = os.path.join(LOGS_DIR, STDERR_FILE) + class BatchException(MetaflowException): - headline = 'AWS Batch error' + headline = "AWS Batch error" class BatchKilledException(MetaflowException): - headline = 'AWS Batch task killed' + headline = "AWS Batch task killed" class Batch(object): @@ -42,22 +48,24 @@ def __init__(self, metadata, environment): self.metadata = metadata self.environment = environment self._client = BatchClient() - atexit.register(lambda: self.job.kill() if hasattr(self, 'job') else None) + atexit.register( + lambda: self.job.kill() if hasattr(self, "job") else None + ) - def _command(self, - environment, - code_package_url, - step_name, - step_cmds, - task_spec): - mflog_expr = export_mflog_env_vars(datastore_type='s3', - stdout_path=STDOUT_PATH, - 
stderr_path=STDERR_PATH, - **task_spec) + def _command( + self, environment, code_package_url, step_name, step_cmds, task_spec + ): + mflog_expr = export_mflog_env_vars( + datastore_type="s3", + stdout_path=STDOUT_PATH, + stderr_path=STDERR_PATH, + **task_spec + ) init_cmds = environment.get_package_commands(code_package_url) - init_expr = ' && '.join(init_cmds) - step_expr = bash_capture_logs(' && '.join( - environment.bootstrap_commands(step_name) + step_cmds)) + init_expr = " && ".join(init_cmds) + step_expr = bash_capture_logs( + " && ".join(environment.bootstrap_commands(step_name) + step_cmds) + ) # construct an entry point that # 1) initializes the mflog environment (mflog_expr) @@ -67,47 +75,52 @@ def _command(self, # the `true` command is to make sure that the generated command # plays well with docker containers which have entrypoint set as # eval $@ - cmd_str = 'true && mkdir -p /logs && %s && %s && %s; ' % \ - (mflog_expr, init_expr, step_expr) + cmd_str = "true && mkdir -p /logs && %s && %s && %s; " % ( + mflog_expr, + init_expr, + step_expr, + ) # after the task has finished, we save its exit code (fail/success) # and persist the final logs. The whole entrypoint should exit # with the exit code (c) of the task. # # Note that if step_expr OOMs, this tail expression is never executed. - # We lose the last logs in this scenario (although they are visible + # We lose the last logs in this scenario (although they are visible # still through AWS CloudWatch console). - cmd_str += 'c=$?; %s; exit $c' % BASH_SAVE_LOGS - return shlex.split('bash -c \"%s\"' % cmd_str) + cmd_str += "c=$?; %s; exit $c" % BASH_SAVE_LOGS + return shlex.split('bash -c "%s"' % cmd_str) def _search_jobs(self, flow_name, run_id, user): if user is None: - regex = '-{flow_name}-'.format(flow_name=flow_name) + regex = "-{flow_name}-".format(flow_name=flow_name) else: - regex = '{user}-{flow_name}-'.format( - user=user, flow_name=flow_name - ) + regex = "{user}-{flow_name}-".format(user=user, flow_name=flow_name) jobs = [] for job in self._client.unfinished_jobs(): - if regex in job['jobName']: - jobs.append(job['jobId']) + if regex in job["jobName"]: + jobs.append(job["jobId"]) if run_id is not None: - run_id = run_id[run_id.startswith('sfn-') and len('sfn-'):] + run_id = run_id[run_id.startswith("sfn-") and len("sfn-") :] for job in self._client.describe_jobs(jobs): - parameters = job['parameters'] - match = (user is None or parameters['metaflow.user'] == user) and \ - (parameters['metaflow.flow_name'] == flow_name) and \ - (run_id is None or parameters['metaflow.run_id'] == run_id) + parameters = job["parameters"] + match = ( + (user is None or parameters["metaflow.user"] == user) + and (parameters["metaflow.flow_name"] == flow_name) + and (run_id is None or parameters["metaflow.run_id"] == run_id) + ) if match: yield job - def _job_name(self, user, flow_name, run_id, step_name, task_id, retry_count): - return '{user}-{flow_name}-{run_id}-{step_name}-{task_id}-{retry_count}'.format( + def _job_name( + self, user, flow_name, run_id, step_name, task_id, retry_count + ): + return "{user}-{flow_name}-{run_id}-{step_name}-{task_id}-{retry_count}".format( user=user, flow_name=flow_name, - run_id=str(run_id) if run_id is not None else '', + run_id=str(run_id) if run_id is not None else "", step_name=step_name, - task_id=str(task_id) if task_id is not None else '', - retry_count=str(retry_count) if retry_count is not None else '' + task_id=str(task_id) if task_id is not None else "", + retry_count=str(retry_count) if 
retry_count is not None else "", ) def list_jobs(self, flow_name, run_id, user, echo): @@ -116,12 +129,12 @@ def list_jobs(self, flow_name, run_id, user, echo): for job in jobs: found = True echo( - '{name} [{id}] ({status})'.format( - name=job['jobName'], id=job['jobId'], status=job['status'] + "{name} [{id}] ({status})".format( + name=job["jobName"], id=job["jobId"], status=job["status"] ) ) if not found: - echo('No running AWS Batch jobs found.') + echo("No running AWS Batch jobs found.") def kill_jobs(self, flow_name, run_id, user, echo): jobs = self._search_jobs(flow_name, run_id, user) @@ -129,19 +142,21 @@ def kill_jobs(self, flow_name, run_id, user, echo): for job in jobs: found = True try: - self._client.attach_job(job['jobId']).kill() + self._client.attach_job(job["jobId"]).kill() echo( - 'Killing AWS Batch job: {name} [{id}] ({status})'.format( - name=job['jobName'], id=job['jobId'], status=job['status'] + "Killing AWS Batch job: {name} [{id}] ({status})".format( + name=job["jobName"], + id=job["jobId"], + status=job["status"], ) ) except Exception as e: echo( - 'Failed to terminate AWS Batch job %s [%s]' - % (job['jobId'], repr(e)) + "Failed to terminate AWS Batch job %s [%s]" + % (job["jobId"], repr(e)) ) if not found: - echo('No running AWS Batch jobs found.') + echo("No running AWS Batch jobs found.") def create_job( self, @@ -167,17 +182,17 @@ def create_job( host_volumes=None, ): job_name = self._job_name( - attrs.get('metaflow.user'), - attrs.get('metaflow.flow_name'), - attrs.get('metaflow.run_id'), - attrs.get('metaflow.step_name'), - attrs.get('metaflow.task_id'), - attrs.get('metaflow.retry_count') + attrs.get("metaflow.user"), + attrs.get("metaflow.flow_name"), + attrs.get("metaflow.run_id"), + attrs.get("metaflow.step_name"), + attrs.get("metaflow.task_id"), + attrs.get("metaflow.retry_count"), ) - job = self._client.job() - job \ - .job_name(job_name) \ - .job_queue(queue) \ + job = ( + self._client.job() + .job_name(job_name) + .job_queue(queue) .command( self._command(self.environment, code_package_url, step_name, [step_cli], task_spec)) \ @@ -204,7 +219,7 @@ def create_job( .environment_variable('METAFLOW_DATASTORE_SYSROOT_S3', DATASTORE_SYSROOT_S3) \ .environment_variable('METAFLOW_DATATOOLS_S3ROOT', DATATOOLS_S3ROOT) \ .environment_variable('METAFLOW_DEFAULT_DATASTORE', 's3') \ - .environment_variable('METAFLOW_DEFAULT_METADATA', DEFAULT_METADATA) + .environment_variable('METAFLOW_DEFAULT_METADATA', DEFAULT_METADATA)) # Skip setting METAFLOW_DATASTORE_SYSROOT_LOCAL because metadata sync between the local user # instance and the remote AWS Batch instance assumes metadata is stored in DATASTORE_LOCAL_DIR # on the remote AWS Batch instance; this happens when METAFLOW_DATASTORE_SYSROOT_LOCAL @@ -235,7 +250,7 @@ def launch_job( image, queue, iam_role=None, - execution_role=None, # for FARGATE compatibility + execution_role=None, # for FARGATE compatibility cpu=None, gpu=None, memory=None, @@ -246,13 +261,13 @@ def launch_job( host_volumes=None, env={}, attrs={}, - ): + ): if queue is None: queue = next(self._client.active_job_queues(), None) if queue is None: raise BatchException( - 'Unable to launch AWS Batch job. No job queue ' - ' specified and no valid & enabled queue found.' + "Unable to launch AWS Batch job. No job queue " + " specified and no valid & enabled queue found." 
) job = self.create_job( step_name, @@ -279,28 +294,29 @@ def launch_job( self.job = job.execute() def wait(self, stdout_location, stderr_location, echo=None): - def wait_for_launch(job): status = job.status - echo('Task is starting (status %s)...' % status, - 'stderr', - batch_id=job.id) + echo( + "Task is starting (status %s)..." % status, + "stderr", + batch_id=job.id, + ) t = time.time() while True: - if status != job.status or (time.time()-t) > 30: + if status != job.status or (time.time() - t) > 30: status = job.status echo( - 'Task is starting (status %s)...' % status, - 'stderr', - batch_id=job.id + "Task is starting (status %s)..." % status, + "stderr", + batch_id=job.id, ) t = time.time() if job.is_running or job.is_done or job.is_crashed: break select.poll().poll(200) - prefix = b'[%s] ' % util.to_bytes(self.job.id) - + prefix = b"[%s] " % util.to_bytes(self.job.id) + def _print_available(tail, stream, should_persist=False): # print the latest batch of lines from S3Tail try: @@ -309,11 +325,14 @@ def _print_available(tail, stream, should_persist=False): line = set_should_persist(line) else: line = refine(line, prefix=prefix) - echo(line.strip().decode('utf-8', errors='replace'), stream) + echo(line.strip().decode("utf-8", errors="replace"), stream) except Exception as ex: - echo('[ temporary error in fetching logs: %s ]' % ex, - 'stderr', - batch_id=self.job.id) + echo( + "[ temporary error in fetching logs: %s ]" % ex, + "stderr", + batch_id=self.job.id, + ) + stdout_tail = S3Tail(stdout_location) stderr_tail = S3Tail(stderr_location) @@ -328,8 +347,8 @@ def _print_available(tail, stream, should_persist=False): while is_running: if time.time() > next_log_update: - _print_available(stdout_tail, 'stdout') - _print_available(stderr_tail, 'stderr') + _print_available(stdout_tail, "stdout") + _print_available(stderr_tail, "stderr") now = time.time() log_update_delay = update_delay(now - start_time) next_log_update = now + log_update_delay @@ -340,7 +359,7 @@ def _print_available(tail, stream, should_persist=False): # a long delay, regardless of the log tailing schedule d = min(log_update_delay, 5.0) select.poll().poll(d * 1000) - + # 3) Fetch remaining logs # # It is possible that we exit the loop above before all logs have been @@ -349,29 +368,33 @@ def _print_available(tail, stream, should_persist=False): # TODO if we notice AWS Batch failing to upload logs to S3, we can add a # HEAD request here to ensure that the file exists prior to calling # S3Tail and note the user about truncated logs if it doesn't - _print_available(stdout_tail, 'stdout') - _print_available(stderr_tail, 'stderr') + _print_available(stdout_tail, "stdout") + _print_available(stderr_tail, "stderr") # In case of hard crashes (OOM), the final save_logs won't happen. - # We fetch the remaining logs from AWS CloudWatch and persist them to + # We fetch the remaining logs from AWS CloudWatch and persist them to # Amazon S3. - # - # TODO: AWS CloudWatch fetch logs if self.job.is_crashed: - msg = next(msg for msg in - [self.job.reason, self.job.status_reason, 'Task crashed.'] - if msg is not None) + msg = next( + msg + for msg in [ + self.job.reason, + self.job.status_reason, + "Task crashed.", + ] + if msg is not None + ) raise BatchException( - '%s ' - 'This could be a transient error. ' - 'Use @retry to retry.' % msg + "%s " + "This could be a transient error. " + "Use @retry to retry." % msg ) else: if self.job.is_running: # Kill the job if it is still running by throwing an exception. 
raise BatchException("Task failed!") echo( - 'Task finished with exit code %s.' % self.job.status_code, - 'stderr', - batch_id=self.job.id + "Task finished with exit code %s." % self.job.status_code, + "stderr", + batch_id=self.job.id, ) diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py index 89d8ec9beea..b5640874a0f 100644 --- a/metaflow/plugins/aws/batch/batch_cli.py +++ b/metaflow/plugins/aws/batch/batch_cli.py @@ -43,18 +43,28 @@ def _execute_cmd(func, flow_name, run_id, user, my_runs, echo): if not run_id and latest_run: run_id = util.get_latest_run_id(echo, flow_name) if run_id is None: - raise CommandException("A previous run id was not found. Specify --run-id.") + raise CommandException( + "A previous run id was not found. Specify --run-id." + ) func(flow_name, run_id, user, echo) @batch.command(help="List unfinished AWS Batch tasks of this flow") -@click.option("--my-runs", default=False, is_flag=True, - help="List all my unfinished tasks.") -@click.option("--user", default=None, - help="List unfinished tasks for the given user.") -@click.option("--run-id", default=None, - help="List unfinished tasks corresponding to the run id.") +@click.option( + "--my-runs", + default=False, + is_flag=True, + help="List all my unfinished tasks.", +) +@click.option( + "--user", default=None, help="List unfinished tasks for the given user." +) +@click.option( + "--run-id", + default=None, + help="List unfinished tasks corresponding to the run id.", +) @click.pass_context def list(ctx, run_id, user, my_runs): batch = Batch(ctx.obj.metadata, ctx.obj.environment) @@ -64,12 +74,22 @@ def list(ctx, run_id, user, my_runs): @batch.command(help="Terminate unfinished AWS Batch tasks of this flow.") -@click.option("--my-runs", default=False, is_flag=True, - help="Kill all my unfinished tasks.") -@click.option("--user", default=None, - help="Terminate unfinished tasks for the given user.") -@click.option("--run-id", default=None, - help="Terminate unfinished tasks corresponding to the run id.") +@click.option( + "--my-runs", + default=False, + is_flag=True, + help="Kill all my unfinished tasks.", +) +@click.option( + "--user", + default=None, + help="Terminate unfinished tasks for the given user.", +) +@click.option( + "--run-id", + default=None, + help="Terminate unfinished tasks corresponding to the run id.", +) @click.pass_context def kill(ctx, run_id, user, my_runs): batch = Batch(ctx.obj.metadata, ctx.obj.environment) @@ -79,24 +99,23 @@ def kill(ctx, run_id, user, my_runs): @batch.command( - help="Execute a single task using AWS Batch. This command " - "calls the top-level step command inside a AWS Batch " - "job with the given options. Typically you do not " - "call this command directly; it is used internally " - "by Metaflow." + help="Execute a single task using AWS Batch. This command calls the " + "top-level step command inside a AWS Batch job with the given options. " + "Typically you do not call this command directly; it is used internally by " + "Metaflow." ) @click.argument("step-name") @click.argument("code-package-sha") @click.argument("code-package-url") @click.option("--executable", help="Executable requirement for AWS Batch.") @click.option( - "--image", help="Docker image requirement for AWS Batch. In name:version format." -) -@click.option( - "--iam-role", help="IAM role requirement for AWS Batch." + "--image", + help="Docker image requirement for AWS Batch. 
In name:version format.", ) +@click.option("--iam-role", help="IAM role requirement for AWS Batch.") @click.option( - "--execution-role", help="Execution role requirement for AWS Batch on Fargate." + "--execution-role", + help="Execution role requirement for AWS Batch on Fargate.", ) @click.option("--cpu", help="CPU requirement for AWS Batch.") @click.option("--gpu", help="GPU requirement for AWS Batch.") @@ -111,17 +130,23 @@ def kill(ctx, run_id, user, my_runs): @click.option( "--tag", multiple=True, default=None, help="Passed to the top-level 'step'." ) -@click.option("--namespace", default=None, help="Passed to the top-level 'step'.") -@click.option("--retry-count", default=0, help="Passed to the top-level 'step'.") +@click.option( + "--namespace", default=None, help="Passed to the top-level 'step'." +) +@click.option( + "--retry-count", default=0, help="Passed to the top-level 'step'." +) @click.option( "--max-user-code-retries", default=0, help="Passed to the top-level 'step'." ) @click.option( "--run-time-limit", default=5 * 24 * 60 * 60, - help="Run time limit in seconds for the AWS Batch job. " "Default is 5 days." + help="Run time limit in seconds for the AWS Batch job. Default is 5 days.", +) +@click.option( + "--shared-memory", help="Shared Memory requirement for AWS Batch." ) -@click.option("--shared-memory", help="Shared Memory requirement for AWS Batch.") @click.option("--max-swap", help="Max Swap requirement for AWS Batch.") @click.option("--swappiness", help="Swappiness requirement for AWS Batch.") #TODO: Maybe remove it altogether since it's not used here @@ -148,10 +173,10 @@ def step( host_volumes=None, **kwargs ): - def echo(msg, stream='stderr', batch_id=None): + def echo(msg, stream="stderr", batch_id=None): msg = util.to_unicode(msg) if batch_id: - msg = '[%s] %s' % (batch_id, msg) + msg = "[%s] %s" % (batch_id, msg) ctx.obj.echo_always(msg, err=(stream == sys.stderr)) if R.use_r(): @@ -159,8 +184,7 @@ def echo(msg, stream='stderr', batch_id=None): else: if executable is None: executable = ctx.obj.environment.executable(step_name) - entrypoint = '%s -u %s' % (executable, - os.path.basename(sys.argv[0])) + entrypoint = "%s -u %s" % (executable, os.path.basename(sys.argv[0])) top_args = " ".join(util.dict_to_cli_options(ctx.parent.parent.params)) @@ -169,14 +193,18 @@ def echo(msg, stream='stderr', batch_id=None): if input_paths: max_size = 30 * 1024 split_vars = { - "METAFLOW_INPUT_PATHS_%d" % (i // max_size): input_paths[i : i + max_size] + "METAFLOW_INPUT_PATHS_%d" + % (i // max_size): input_paths[i : i + max_size] for i in range(0, len(input_paths), max_size) } kwargs["input_paths"] = "".join("${%s}" % s for s in split_vars.keys()) step_args = " ".join(util.dict_to_cli_options(kwargs)) step_cli = u"{entrypoint} {top_args} step {step} {step_args}".format( - entrypoint=entrypoint, top_args=top_args, step=step_name, step_args=step_args + entrypoint=entrypoint, + top_args=top_args, + step=step_name, + step_args=step_args, ) node = ctx.obj.graph[step_name] @@ -191,17 +219,17 @@ def echo(msg, stream='stderr', batch_id=None): # Set batch attributes task_spec = { - 'flow_name': ctx.obj.flow.name, - 'step_name': step_name, - 'run_id': kwargs['run_id'], - 'task_id': kwargs['task_id'], - 'retry_count': str(retry_count) + "flow_name": ctx.obj.flow.name, + "step_name": step_name, + "run_id": kwargs["run_id"], + "task_id": kwargs["task_id"], + "retry_count": str(retry_count), } - attrs = {'metaflow.%s' % k: v for k, v in task_spec.items()} - attrs['metaflow.user'] = 
util.get_username() - attrs['metaflow.version'] = ctx.obj.environment.get_environment_info()[ - "metaflow_version" - ] + attrs = {"metaflow.%s" % k: v for k, v in task_spec.items()} + attrs["metaflow.user"] = util.get_username() + attrs["metaflow.version"] = ctx.obj.environment.get_environment_info()[ + "metaflow_version" + ] env_deco = [deco for deco in node.decorators if deco.name == "environment"] if env_deco: @@ -215,7 +243,8 @@ def echo(msg, stream='stderr', batch_id=None): if retry_count: ctx.obj.echo_always( - "Sleeping %d minutes before the next AWS Batch retry" % minutes_between_retries + "Sleeping %d minutes before the next AWS Batch retry" + % minutes_between_retries ) time.sleep(minutes_between_retries * 60) @@ -264,7 +293,7 @@ def _sync_metadata(): host_volumes=host_volumes, ) except Exception as e: - print(e) + traceback.print_exc() _sync_metadata() sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) try: diff --git a/metaflow/plugins/aws/batch/batch_client.py b/metaflow/plugins/aws/batch/batch_client.py index c3b3e97e479..64e988d5c2e 100644 --- a/metaflow/plugins/aws/batch/batch_client.py +++ b/metaflow/plugins/aws/batch/batch_client.py @@ -485,64 +485,8 @@ def status_code(self): self.update() return self.info['container'].get('exitCode') - def wait_for_running(self): - if not self.is_running and not self.is_done: - BatchWaiter(self._client).wait_for_running(self.id) - def kill(self): if not self.is_done: self._client.terminate_job( jobId=self._id, reason='Metaflow initiated job termination.') return self.update() - - -class BatchWaiter(object): - def __init__(self, client): - try: - from botocore import waiter - except: - raise BatchJobException( - 'Could not import module \'botocore\' which ' - 'is required for Batch jobs. Install botocore ' - 'first.' 
-            )
-        self._client = client
-        self._waiter = waiter
-
-    def wait_for_running(self, job_id):
-        model = self._waiter.WaiterModel(
-            {
-                'version': 2,
-                'waiters': {
-                    'JobRunning': {
-                        'delay': 1,
-                        'operation': 'DescribeJobs',
-                        'description': 'Wait until job starts running',
-                        'maxAttempts': 1000000,
-                        'acceptors': [
-                            {
-                                'argument': 'jobs[].status',
-                                'expected': 'SUCCEEDED',
-                                'matcher': 'pathAll',
-                                'state': 'success',
-                            },
-                            {
-                                'argument': 'jobs[].status',
-                                'expected': 'FAILED',
-                                'matcher': 'pathAny',
-                                'state': 'success',
-                            },
-                            {
-                                'argument': 'jobs[].status',
-                                'expected': 'RUNNING',
-                                'matcher': 'pathAny',
-                                'state': 'success',
-                            },
-                        ],
-                    }
-                },
-            }
-        )
-        self._waiter.create_waiter_with_client('JobRunning', model, self._client).wait(
-            jobs=[job_id]
-        )
\ No newline at end of file
diff --git a/metaflow/plugins/aws/eks/__init__.py b/metaflow/plugins/aws/eks/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/metaflow/plugins/aws/eks/kubernetes.py b/metaflow/plugins/aws/eks/kubernetes.py
new file mode 100644
index 00000000000..6592e6aa177
--- /dev/null
+++ b/metaflow/plugins/aws/eks/kubernetes.py
@@ -0,0 +1,405 @@
+import os
+import time
+import json
+import select
+import shlex
+import time
+import re
+import hashlib
+
+from metaflow import util
+from metaflow.datatools.s3tail import S3Tail
+from metaflow.exception import MetaflowException, MetaflowInternalError
+from metaflow.metaflow_config import (
+    BATCH_METADATA_SERVICE_URL,
+    DATATOOLS_S3ROOT,
+    DATASTORE_LOCAL_DIR,
+    DATASTORE_SYSROOT_S3,
+    DEFAULT_METADATA,
+    BATCH_METADATA_SERVICE_HEADERS,
+)
+from metaflow.mflog import (
+    export_mflog_env_vars,
+    bash_capture_logs,
+    update_delay,
+    BASH_SAVE_LOGS
+)
+from metaflow.mflog.mflog import refine, set_should_persist
+
+from .kubernetes_client import KubernetesClient
+
+# Redirect structured logs to /logs/
+LOGS_DIR = "/logs"
+STDOUT_FILE = "mflog_stdout"
+STDERR_FILE = "mflog_stderr"
+STDOUT_PATH = os.path.join(LOGS_DIR, STDOUT_FILE)
+STDERR_PATH = os.path.join(LOGS_DIR, STDERR_FILE)
+
+
+class KubernetesException(MetaflowException):
+    headline = "Kubernetes error"
+
+
+class KubernetesKilledException(MetaflowException):
+    headline = "Kubernetes Batch job killed"
+
+
+def generate_rfc1123_name(flow_name,
+                          run_id,
+                          step_name,
+                          task_id,
+                          attempt
+):
+    """
+    Generate RFC 1123 compatible name. Specifically, the format is:
+        [*[]]
+
+    The generated name consists of a human-readable prefix, derived from
+    flow/step/task/attempt, and a hash suffix.
+    """
+    long_name = "-".join(
+        [
+            flow_name,
+            run_id,
+            step_name,
+            task_id,
+            attempt,
+        ]
+    )
+    hash = hashlib.sha256(long_name.encode('utf-8')).hexdigest()
+
+    if long_name.startswith('_'):
+        # RFC 1123 names can't start with a hyphen, so slap an extra prefix on it
+        sanitized_long_name = 'u' + long_name.replace('_', '-').lower()
+    else:
+        sanitized_long_name = long_name.replace('_', '-').lower()
+
+    # the name has to be under 63 chars total
+    return sanitized_long_name[:57] + '-' + hash[:5]
+
+
+LABEL_VALUE_REGEX = re.compile(r'^[a-zA-Z0-9]([a-zA-Z0-9\-\_\.]{0,61}[a-zA-Z0-9])?$')
+
+
+def sanitize_label_value(val):
+    # Label sanitization: if the value can be used as is, return it as is.
+    # If it can't, sanitize and add a suffix based on hash of the original
+    # value, replace invalid chars and truncate.
+    #
+    # The idea here is that even if there are non-allowed chars in the same
+    # position, this function will likely return distinct values, so you can
+    # still filter on those.
For example, "alice$" and "alice&" will be + # sanitized into different values "alice_b3f201" and "alice_2a6f13". + if val == '' or LABEL_VALUE_REGEX.match(val): + return val + hash = hashlib.sha256(val.encode('utf-8')).hexdigest() + + # Replace invalid chars with dots, and if the first char is + # non-alphahanumeric, replace it with 'u' to make it valid + sanitized_val = re.sub('^[^A-Z0-9a-z]', 'u', re.sub(r"[^A-Za-z0-9.\-_]", "_", val)) + return sanitized_val[:57] + '-' + hash[:5] + + +class Kubernetes(object): + def __init__( + self, + datastore, + metadata, + environment, + flow_name, + run_id, + step_name, + task_id, + attempt, + ): + self._datastore = datastore + self._metadata = metadata + self._environment = environment + + self._flow_name = flow_name + self._run_id = run_id + self._step_name = step_name + self._task_id = task_id + self._attempt = str(attempt) + + def _command( + self, + code_package_url, + step_cmds, + ): + mflog_expr = export_mflog_env_vars( + flow_name=self._flow_name, + run_id=self._run_id, + step_name=self._step_name, + task_id=self._task_id, + retry_count=self._attempt, + datastore_type=self._datastore.TYPE, + stdout_path=STDOUT_PATH, + stderr_path=STDERR_PATH, + ) + init_cmds = self._environment.get_package_commands(code_package_url) + init_expr = " && ".join(init_cmds) + step_expr = bash_capture_logs( + " && ".join( + self._environment.bootstrap_commands(self._step_name) + + step_cmds + ) + ) + + # Construct an entry point that + # 1) initializes the mflog environment (mflog_expr) + # 2) bootstraps a metaflow environment (init_expr) + # 3) executes a task (step_expr) + + # The `true` command is to make sure that the generated command + # plays well with docker containers which have entrypoint set as + # eval $@ + cmd_str = "true && mkdir -p /logs && %s && %s && %s; " % ( + mflog_expr, + init_expr, + step_expr, + ) + # After the task has finished, we save its exit code (fail/success) + # and persist the final logs. The whole entrypoint should exit + # with the exit code (c) of the task. + # + # Note that if step_expr OOMs, this tail expression is never executed. + # We lose the last logs in this scenario. + # + # TODO: Find a way to capture hard exit logs in Kubernetes. + cmd_str += "c=$?; %s; exit $c" % BASH_SAVE_LOGS + return shlex.split('bash -c "%s"' % cmd_str) + + def launch_job(self, **kwargs): + self._job = self.create_job(**kwargs).execute() + + def create_job( + self, + user, + code_package_sha, + code_package_url, + code_package_ds, + step_cli, + docker_image, + service_account=None, + secrets=None, + node_selector=None, + namespace=None, + cpu=None, + gpu=None, + disk=None, + memory=None, + run_time_limit=None, + env={}, + ): + # TODO: Test for DNS-1123 compliance. Python names can have underscores + # which are not valid Kubernetes names. We can potentially make + # the pathspec DNS-1123 compliant by stripping away underscores + # etc. and relying on Kubernetes to attach a suffix to make the + # name unique within a namespace. + # + # Set the pathspec (along with attempt) as the Kubernetes job name. + # Kubernetes job names are supposed to be unique within a Kubernetes + # namespace and compliant with DNS-1123. The pathspec (with attempt) + # can provide that guarantee, however, for flows launched via AWS Step + # Functions (and potentially Argo), we may not get the task_id or the + # attempt_id while submitting the job to the Kubernetes cluster. If + # that is indeed the case, we can rely on Kubernetes to generate a name + # for us. 
+ job_name = generate_rfc1123_name( + self._flow_name, + self._run_id, + self._step_name, + self._task_id, + self._attempt, + ) + + job = ( + KubernetesClient() + .job( + name=job_name, + namespace=namespace, + service_account=service_account, + secrets=secrets, + node_selector=node_selector, + command=self._command( + code_package_url=code_package_url, + step_cmds=[step_cli], + ), + image=docker_image, + cpu=cpu, + memory=memory, + disk=disk, + timeout_in_seconds=run_time_limit, + # Retries are handled by Metaflow runtime + retries=0, + ) + .environment_variable( + # This is needed since `boto3` is not smart enough to figure out + # AWS region by itself. + # TODO: Fix this. + "AWS_DEFAULT_REGION", + "us-west-2", + ) + .environment_variable("METAFLOW_CODE_SHA", code_package_sha) + .environment_variable("METAFLOW_CODE_URL", code_package_url) + .environment_variable("METAFLOW_CODE_DS", code_package_ds) + .environment_variable("METAFLOW_USER", user) + .environment_variable( + "METAFLOW_SERVICE_URL", BATCH_METADATA_SERVICE_URL + ) + .environment_variable( + "METAFLOW_SERVICE_HEADERS", + json.dumps(BATCH_METADATA_SERVICE_HEADERS), + ) + .environment_variable( + "METAFLOW_DATASTORE_SYSROOT_S3", DATASTORE_SYSROOT_S3 + ) + .environment_variable("METAFLOW_DATATOOLS_S3ROOT", DATATOOLS_S3ROOT) + .environment_variable("METAFLOW_DEFAULT_DATASTORE", "s3") + .environment_variable("METAFLOW_DEFAULT_METADATA", DEFAULT_METADATA) + .environment_variable("METAFLOW_KUBERNETES_WORKLOAD", 1) + .label("app", "metaflow") + .label("metaflow/flow_name", sanitize_label_value(self._flow_name)) + .label("metaflow/run_id", sanitize_label_value(self._run_id)) + .label("metaflow/step_name", sanitize_label_value(self._step_name)) + .label("metaflow/task_id", sanitize_label_value(self._task_id)) + .label("metaflow/attempt", sanitize_label_value(self._attempt)) + ) + + # Skip setting METAFLOW_DATASTORE_SYSROOT_LOCAL because metadata sync + # between the local user instance and the remote Kubernetes pod + # assumes metadata is stored in DATASTORE_LOCAL_DIR on the Kubernetes + # pod; this happens when METAFLOW_DATASTORE_SYSROOT_LOCAL is NOT set ( + # see get_datastore_root_from_config in datastore/local.py). + for name, value in env.items(): + job.environment_variable(name, value) + + # Add labels to the Kubernetes job + # + # Apply recommended labels https://kubernetes.io/docs/concepts/overview/working-with-objects/common-labels/ + # + # TODO: 1. Verify the behavior of high cardinality labels like instance, + # version etc. in the app.kubernetes.io namespace before + # introducing them here. + job.label("app.kubernetes.io/name", "metaflow-task").label( + "app.kubernetes.io/part-of", "metaflow" + ).label("app.kubernetes.io/created-by", sanitize_label_value(user)) + # Add Metaflow system tags as labels as well! + for sys_tag in self._metadata.sticky_sys_tags: + job.label( + "metaflow/%s" % sys_tag[: sys_tag.index(":")], + sanitize_label_value(sys_tag[sys_tag.index(":") + 1 :]) + ) + # TODO: Add annotations based on https://kubernetes.io/blog/2021/04/20/annotating-k8s-for-humans/ + + return job.create() + + def wait(self, stdout_location, stderr_location, echo=None): + + def wait_for_launch(job): + status = job.status + echo( + "Task is starting (Status %s)..." % status, + "stderr", + job_id=job.id, + ) + t = time.time() + while True: + new_status = job.status + if status != new_status or (time.time() - t) > 30: + status = new_status + echo( + "Task is starting (Status %s)..." 
% status, + "stderr", + job_id=job.id, + ) + t = time.time() + if job.is_running or job.is_done: + break + time.sleep(1) + + def _print_available(tail, stream, should_persist=False): + # print the latest batch of lines from S3Tail + prefix = b"[%s] " % util.to_bytes(self._job.id) + try: + for line in tail: + if should_persist: + line = set_should_persist(line) + else: + line = refine(line, prefix=prefix) + echo(line.strip().decode("utf-8", errors="replace"), stream) + except Exception as ex: + echo( + "[ temporary error in fetching logs: %s ]" % ex, + "stderr", + job_id=self._job.id, + ) + + stdout_tail = S3Tail(stdout_location) + stderr_tail = S3Tail(stderr_location) + + # 1) Loop until the job has started + wait_for_launch(self._job) + + # 2) Loop until the job has finished + start_time = time.time() + is_running = True + next_log_update = start_time + log_update_delay = 1 + + while is_running: + if time.time() > next_log_update: + _print_available(stdout_tail, "stdout") + _print_available(stderr_tail, "stderr") + now = time.time() + log_update_delay = update_delay(now - start_time) + next_log_update = now + log_update_delay + is_running = self._job.is_running + + # This sleep should never delay log updates. On the other hand, + # we should exit this loop when the task has finished without + # a long delay, regardless of the log tailing schedule + time.sleep(min(log_update_delay, 5.0)) + + # 3) Fetch remaining logs + # + # It is possible that we exit the loop above before all logs have been + # shown. + # + # TODO (savin): If we notice Kubernetes failing to upload logs to S3, + # we can add a HEAD request here to ensure that the file + # exists prior to calling S3Tail and note the user about + # truncated logs if it doesn't. + # TODO (savin): For hard crashes, we can fetch logs from the pod. + _print_available(stdout_tail, "stdout") + _print_available(stderr_tail, "stderr") + + if self._job.has_failed: + exit_code, reason = self._job.reason + msg = next( + msg + for msg in [ + reason, + "Task crashed", + ] + if msg is not None + ) + if exit_code: + if int(exit_code) == 139: + raise KubernetesException( + "Task failed with a segmentation fault." + ) + else: + msg = "%s (exit code %s)" % (msg, exit_code) + raise KubernetesException( + "%s. This could be a transient error. " + "Use @retry to retry." % msg + ) + + exit_code, _ = self._job.reason + echo( + "Task finished with exit code %s." % exit_code, + "stderr", + job_id=self._job.id, + ) diff --git a/metaflow/plugins/aws/eks/kubernetes_cli.py b/metaflow/plugins/aws/eks/kubernetes_cli.py new file mode 100644 index 00000000000..32315986b47 --- /dev/null +++ b/metaflow/plugins/aws/eks/kubernetes_cli.py @@ -0,0 +1,234 @@ +import click +import os +import sys +import time +import traceback + +from metaflow import util +from metaflow.exception import CommandException, METAFLOW_EXIT_DISALLOW_RETRY +from metaflow.metadata.util import sync_local_metadata_from_datastore +from metaflow.metaflow_config import DATASTORE_LOCAL_DIR +from metaflow.mflog import TASK_LOG_SOURCE + +from .kubernetes import Kubernetes, KubernetesKilledException + +# TODO(s): +# 1. Compatibility for Metaflow-R (not a blocker for release). +# 2. Add more CLI commands to manage Kubernetes objects. + + +@click.group() +def cli(): + pass + + +@cli.group(help="Commands related to Kubernetes on Amazon EKS.") +def kubernetes(): + pass + + +@kubernetes.command( + help="Execute a single task on Kubernetes using Amazon EKS. 
This command " + "calls the top-level step command inside a Kubernetes job with the given " + "options. Typically you do not call this command directly; it is used " + "internally by Metaflow." +) +@click.argument("step-name") +@click.argument("code-package-sha") +@click.argument("code-package-url") +@click.option( + "--executable", + help="Executable requirement for Kubernetes job on Amazon EKS.", +) +@click.option( + "--image", help="Docker image requirement for Kubernetes job on Amazon EKS." +) +@click.option( + "--service-account", + help="IRSA requirement for Kubernetes job on Amazon EKS.", +) +@click.option( + "--secrets", + multiple=True, + default=None, + help="Secrets for Kubernetes job on Amazon EKS.", +) +@click.option( + "--node-selector", + multiple=True, + default=None, + help="NodeSelector for Kubernetes job on Amazon EKS.", +) +@click.option( + # Note that ideally we would have liked to use `namespace` rather than + # `k8s-namespace` but unfortunately, `namespace` is already reserved for + # Metaflow namespaces. + "--k8s-namespace", + default=None, + help="Namespace for Kubernetes job on Amazon EKS.", +) +@click.option("--cpu", help="CPU requirement for Kubernetes job on Amazon EKS.") +@click.option("--gpu", help="GPU requirement for Kubernetes job on Amazon EKS.") +@click.option( + "--disk", help="Disk requirement for Kubernetes job on Amazon EKS." +) +@click.option( + "--memory", help="Memory requirement for Kubernetes job on Amazon EKS." +) +@click.option("--run-id", help="Passed to the top-level 'step'.") +@click.option("--task-id", help="Passed to the top-level 'step'.") +@click.option("--input-paths", help="Passed to the top-level 'step'.") +@click.option("--split-index", help="Passed to the top-level 'step'.") +@click.option("--clone-path", help="Passed to the top-level 'step'.") +@click.option("--clone-run-id", help="Passed to the top-level 'step'.") +@click.option( + "--tag", multiple=True, default=None, help="Passed to the top-level 'step'." +) +@click.option( + "--namespace", default=None, help="Passed to the top-level 'step'." +) +@click.option( + "--retry-count", default=0, help="Passed to the top-level 'step'." +) +@click.option( + "--max-user-code-retries", default=0, help="Passed to the top-level 'step'." +) +@click.option( + "--run-time-limit", + default=5 * 24 * 60 * 60, # Default is set to 5 days + help="Run time limit in seconds for Kubernetes job.", +) +@click.pass_context +def step( + ctx, + step_name, + code_package_sha, + code_package_url, + executable=None, + image=None, + service_account=None, + secrets=None, + node_selector=None, + k8s_namespace=None, + cpu=None, + gpu=None, + disk=None, + memory=None, + run_time_limit=None, + **kwargs +): + def echo(msg, stream="stderr", job_id=None): + msg = util.to_unicode(msg) + if job_id: + msg = "[%s] %s" % (job_id, msg) + ctx.obj.echo_always(msg, err=(stream == sys.stderr)) + + node = ctx.obj.graph[step_name] + + # Construct entrypoint CLI + if executable is None: + executable = ctx.obj.environment.executable(step_name) + + # Set environment + env = {} + env_deco = [deco for deco in node.decorators if deco.name == "environment"] + if env_deco: + env = env_deco[0].attributes["vars"] + + # Set input paths. 
+ input_paths = kwargs.get("input_paths") + split_vars = None + if input_paths: + max_size = 30 * 1024 + split_vars = { + "METAFLOW_INPUT_PATHS_%d" + % (i // max_size): input_paths[i : i + max_size] + for i in range(0, len(input_paths), max_size) + } + kwargs["input_paths"] = "".join("${%s}" % s for s in split_vars.keys()) + env.update(split_vars) + + # Set retry policy. + retry_count = int(kwargs.get("retry_count", 0)) + retry_deco = [deco for deco in node.decorators if deco.name == "retry"] + minutes_between_retries = None + if retry_deco: + minutes_between_retries = int( + retry_deco[0].attributes.get("minutes_between_retries", 2) + ) + if retry_count: + ctx.obj.echo_always( + "Sleeping %d minutes before the next retry" + % minutes_between_retries + ) + time.sleep(minutes_between_retries * 60) + + step_cli = u"{entrypoint} {top_args} step {step} {step_args}".format( + entrypoint="%s -u %s" % (executable, os.path.basename(sys.argv[0])), + top_args=" ".join(util.dict_to_cli_options(ctx.parent.parent.params)), + step=step_name, + step_args=" ".join(util.dict_to_cli_options(kwargs)), + ) + + # this information is needed for log tailing + ds = ctx.obj.flow_datastore.get_task_datastore( + mode='w', + run_id=kwargs['run_id'], + step_name=step_name, + task_id=kwargs['task_id'], + attempt=int(retry_count) + ) + stdout_location = ds.get_log_location(TASK_LOG_SOURCE, 'stdout') + stderr_location = ds.get_log_location(TASK_LOG_SOURCE, 'stderr') + + def _sync_metadata(): + if ctx.obj.metadata.TYPE == 'local': + sync_local_metadata_from_datastore( + DATASTORE_LOCAL_DIR, + ctx.obj.flow_datastore.get_task_datastore(kwargs['run_id'], + step_name, + kwargs['task_id'])) + + try: + kubernetes = Kubernetes( + datastore=ctx.obj.flow_datastore, + metadata=ctx.obj.metadata, + environment=ctx.obj.environment, + flow_name=ctx.obj.flow.name, + run_id=kwargs["run_id"], + step_name=step_name, + task_id=kwargs["task_id"], + attempt=retry_count, + ) + # Configure and launch Kubernetes job. 
+ with ctx.obj.monitor.measure("metaflow.aws.eks.launch_job"): + kubernetes.launch_job( + user=util.get_username(), + code_package_sha=code_package_sha, + code_package_url=code_package_url, + code_package_ds=ctx.obj.flow_datastore.TYPE, + step_cli=step_cli, + docker_image=image, + service_account=service_account, + secrets=secrets, + node_selector=node_selector, + namespace=k8s_namespace, + cpu=cpu, + gpu=gpu, + disk=disk, + memory=memory, + run_time_limit=run_time_limit, + env=env, + ) + except Exception as e: + traceback.print_exc() + _sync_metadata() + sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) + try: + kubernetes.wait(stdout_location, stderr_location, echo=echo) + except KubernetesKilledException: + # don't retry killed tasks + traceback.print_exc() + sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) + finally: + _sync_metadata() \ No newline at end of file diff --git a/metaflow/plugins/aws/eks/kubernetes_client.py b/metaflow/plugins/aws/eks/kubernetes_client.py new file mode 100644 index 00000000000..b0c143328c0 --- /dev/null +++ b/metaflow/plugins/aws/eks/kubernetes_client.py @@ -0,0 +1,716 @@ +import os +import time +import math +import random + + +try: + unicode +except NameError: + unicode = str + basestring = str + +from metaflow.exception import MetaflowException + +CLIENT_REFRESH_INTERVAL_SECONDS = 300 + + +class KubernetesJobException(MetaflowException): + headline = "Kubernetes job error" + + +# Implements truncated exponential backoff from https://cloud.google.com/storage/docs/retry-strategy#exponential-backoff +def k8s_retry(deadline_seconds=60, max_backoff=32): + def decorator(function): + from functools import wraps + + @wraps(function) + def wrapper(*args, **kwargs): + from kubernetes import client + + deadline = time.time() + deadline_seconds + retry_number = 0 + + while True: + try: + result = function(*args, **kwargs) + return result + except client.rest.ApiException as e: + if e.status == 500: + current_t = time.time() + backoff_delay = min(math.pow(2, retry_number) + random.random(), max_backoff) + if current_t + backoff_delay < deadline: + time.sleep(backoff_delay) + retry_number += 1 + continue # retry again + else: + raise + else: + raise + + return wrapper + return decorator + + +class KubernetesClient(object): + def __init__(self): + # TODO: Look into removing the usage of Kubernetes Python SDK + # at some point in the future. Given that Kubernetes Python SDK + # aggressively drops support for older kubernetes clusters, continued + # dependency on it may bite our users. + + try: + # Kubernetes is a soft dependency. + from kubernetes import client, config + except (NameError, ImportError): + raise MetaflowException( + "Could not import module 'kubernetes'. Install kubernetes " + "Python package (https://pypi.org/project/kubernetes/) first." + ) + self._refresh_client() + + def _refresh_client(self): + from kubernetes import client, config + + if os.getenv("KUBERNETES_SERVICE_HOST"): + # We are inside a pod, authenticate via ServiceAccount assigned to us + config.load_incluster_config() + else: + # Use kubeconfig, likely $HOME/.kube/config + # TODO (savin): + # 1. Support generating kubeconfig on the fly using boto3 + # 2. Support auth via OIDC - https://docs.aws.amazon.com/eks/latest/userguide/authenticate-oidc-identity-provider.html + # Supporting the above auth mechanisms (atleast 1.) should be + # good enough for the initial rollout. 
+ config.load_kube_config() + self._client = client + self._client_refresh_timestamp = time.time() + + def job(self, **kwargs): + return KubernetesJob(self, **kwargs) + + def get(self): + if ( + time.time() - self._client_refresh_timestamp + > CLIENT_REFRESH_INTERVAL_SECONDS + ): + self._refresh_client() + + return self._client + + +class KubernetesJob(object): + def __init__(self, client_wrapper, **kwargs): + self._client_wrapper = client_wrapper + self._kwargs = kwargs + + # Kubernetes namespace defaults to `default` + self._kwargs["namespace"] = self._kwargs["namespace"] or "default" + + def create(self): + # Check that job attributes are sensible. + + # CPU value should be greater than 0 + if not ( + isinstance(self._kwargs["cpu"], (int, unicode, basestring, float)) + and float(self._kwargs["cpu"]) > 0 + ): + raise KubernetesJobException( + "Invalid CPU value ({}); it should be greater than 0".format( + self._kwargs["cpu"] + ) + ) + + # Memory value should be greater than 0 + if not ( + isinstance(self._kwargs["memory"], (int, unicode, basestring)) + and int(self._kwargs["memory"]) > 0 + ): + raise KubernetesJobException( + "Invalid memory value ({}); it should be greater than 0".format( + self._kwargs["memory"] + ) + ) + + # Disk value should be greater than 0 + if not ( + isinstance(self._kwargs["disk"], (int, unicode, basestring)) + and int(self._kwargs["disk"]) > 0 + ): + raise KubernetesJobException( + "Invalid disk value ({}); it should be greater than 0".format( + self._kwargs["disk"] + ) + ) + + # TODO(s) (savin) + # 1. Add support for GPUs. + + # A discerning eye would notice and question the choice of using the + # V1Job construct over the V1Pod construct given that we don't rely much + # on any of the V1Job semantics. The major reasons at the moment are - + # 1. It makes the Kubernetes UIs (Octant, Lens) a bit more easy on + # the eyes, although even that can be questioned. + # 2. AWS Step Functions, at the moment (Aug' 21) only supports + # executing Jobs and not Pods as part of it's publicly declared + # API. When we ship the AWS Step Functions integration with EKS, + # it will hopefully lessen our workload. + # + # Note: This implementation ensures that there is only one unique Pod + # (unique UID) per Metaflow task attempt. + client = self._client_wrapper.get() + self._job = client.V1Job( + api_version="batch/v1", + kind="Job", + metadata=client.V1ObjectMeta( + # Annotations are for humans + annotations=self._kwargs.get("annotations", {}), + # While labels are for Kubernetes + labels=self._kwargs.get("labels", {}), + name=self._kwargs["name"], # Unique within the namespace + namespace=self._kwargs["namespace"], # Defaults to `default` + ), + spec=client.V1JobSpec( + # Retries are handled by Metaflow when it is responsible for + # executing the flow. The responsibility is moved to Kubernetes + # when AWS Step Functions / Argo are responsible for the + # execution. + backoff_limit=self._kwargs.get("retries", 0), + completions=1, # A single non-indexed pod job + # TODO (savin): Implement a job clean-up option in the + # kubernetes CLI. + ttl_seconds_after_finished=7 + * 60 + * 60 # Remove job after a week. TODO (savin): Make this + * 24, # configurable + template=client.V1PodTemplateSpec( + metadata=client.V1ObjectMeta( + annotations=self._kwargs.get("annotations", {}), + labels=self._kwargs.get("labels", {}), + name=self._kwargs["name"], + namespace=self._kwargs["namespace"], + ), + spec=client.V1PodSpec( + # Timeout is set on the pod and not the job (important!) 
+ active_deadline_seconds=self._kwargs["timeout_in_seconds"], + # TODO (savin): Enable affinities for GPU scheduling. + # This requires some thought around the + # UX since specifying affinities can get + # complicated quickly. We may well decide + # to move it out of scope for the initial + # roll out. + # affinity=?, + containers=[ + client.V1Container( + command=self._kwargs["command"], + env=[ + client.V1EnvVar(name=k, value=str(v)) + for k, v in self._kwargs.get( + "environment_variables", {} + ).items() + ] + # And some downward API magic. Add (key, value) + # pairs below to make pod metadata available + # within Kubernetes container. + # + # TODO: Figure out a way to make job + # metadata visible within the container + + [ + client.V1EnvVar( + name=k, + value_from=client.V1EnvVarSource( + field_ref=client.V1ObjectFieldSelector( + field_path=str(v) + ) + ), + ) + for k, v in { + "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace", + "METAFLOW_KUBERNETES_POD_NAME": "metadata.name", + "METAFLOW_KUBERNETES_POD_ID": "metadata.uid", + }.items() + ], + env_from=[ + client.V1EnvFromSource( + secret_ref=client.V1SecretEnvSource(name=str(k)) + ) + for k in self._kwargs.get("secrets", []) + ], + image=self._kwargs["image"], + name=self._kwargs["name"], + resources=client.V1ResourceRequirements( + requests={ + "cpu": str(self._kwargs["cpu"]), + "memory": "%sM" % str(self._kwargs["memory"]), + "ephemeral-storage": "%sM" + % str(self._kwargs["disk"]), + } + ), + ) + ], + node_selector={ + # TODO: What should be the format of node selector - + # key:value or key=value? + str(k.split("=", 1)[0]): str(k.split("=", 1)[1]) + for k in self._kwargs.get("node_selector", []) + }, + # TODO (savin): At some point in the very near future, + # support docker access secrets. + # image_pull_secrets=?, + # + # TODO (savin): We should, someday, get into the pod + # priority business + # preemption_policy=?, + # + # A Container in a Pod may fail for a number of + # reasons, such as because the process in it exited + # with a non-zero exit code, or the Container was + # killed due to OOM etc. If this happens, fail the pod + # and let Metaflow handle the retries. + restart_policy="Never", + service_account_name=self._kwargs["service_account"], + # Terminate the container immediately on SIGTERM + termination_grace_period_seconds=0, + # TODO (savin): Enable tolerations for GPU scheduling. + # This requires some thought around the + # UX since specifying tolerations can get + # complicated quickly. + # tolerations=?, + # + # TODO (savin): At some point in the very near future, + # support custom volumes (PVCs/EVCs). + # volumes=?, + # + # TODO (savin): Set termination_message_policy + ), + ), + ), + ) + return self + + def execute(self): + client = self._client_wrapper.get() + try: + # TODO (savin): Make job submission back-pressure aware. Currently + # there doesn't seem to be a kubernetes-native way to + # achieve the guarantees that we are seeking. + # Hopefully, we will be able to get creative soon. 
+ response = ( + client.BatchV1Api() + .create_namespaced_job( + body=self._job, namespace=self._kwargs["namespace"] + ) + .to_dict() + ) + return RunningJob( + client_wrapper=self._client_wrapper, + name=response["metadata"]["name"], + uid=response["metadata"]["uid"], + namespace=response["metadata"]["namespace"], + ) + except client.rest.ApiException as e: + raise KubernetesJobException( + "Unable to launch Kubernetes job.\n %s" % str(e) + ) + + def namespace(self, namespace): + self._kwargs["namespace"] = namespace + return self + + def name(self, name): + self._kwargs["name"] = name + return self + + def command(self, command): + self._kwargs["command"] = command + return self + + def image(self, image): + self._kwargs["image"] = image + return self + + def cpu(self, cpu): + self._kwargs["cpu"] = cpu + return self + + def memory(self, mem): + self._kwargs["memory"] = mem + return self + + def environment_variable(self, name, value): + self._kwargs["environment_variables"] = dict( + self._kwargs.get("environment_variables", {}), **{name: value} + ) + return self + + def label(self, name, value): + self._kwargs["labels"] = dict(self._kwargs.get("labels", {}), **{name: value}) + return self + + def annotation(self, name, value): + self._kwargs["annotations"] = dict( + self._kwargs.get("annotations", {}), **{name: value} + ) + return self + + +class RunningJob(object): + + # State Machine implementation for the lifecycle behavior documented in + # https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/ + # + # This object encapsulates *both* V1Job and V1Pod. It simplifies the status + # to "running" and "done" (failed/succeeded) state. Note that V1Job and V1Pod + # status fields are not guaranteed to be always in sync due to the way job + # controller works. + # + # For example, for a successful job, RunningJob states and their corresponding + # K8S object states look like this: + # + # | V1JobStatus.active | V1JobStatus.succeeded | V1PodStatus.phase | RunningJob.is_running | RunningJob.is_done | + # |--------------------|-----------------------|-------------------|-----------------------|--------------------| + # | 0 | 0 | N/A | False | False | + # | 0 | 0 | Pending | False | False | + # | 1 | 0 | Running | True | False | + # | 1 | 0 | Succeeded | True | True | + # | 0 | 1 | Succeeded | False | True | + + # To ascertain the status of V1Job, we peer into the lifecycle status of + # the pod it is responsible for executing. Unfortunately, the `phase` + # attributes (pending, running, succeeded, failed etc.) only provide + # partial answers and the official API conventions guide suggests that + # it may soon be deprecated (however, not anytime soon - see + # https://github.com/kubernetes/kubernetes/issues/7856). `conditions` otoh + # provide a deeper understanding about the state of the pod; however + # conditions are not state machines and can be oscillating - from the + # offical API conventions guide: + # In general, condition values may change back and forth, but some + # condition transitions may be monotonic, depending on the resource and + # condition type. However, conditions are observations and not, + # themselves, state machines, nor do we define comprehensive state + # machines for objects, nor behaviors associated with state + # transitions. The system is level-based rather than edge-triggered, + # and should assume an Open World. 
+ # In this implementation, we synthesize our notion of "phase" state + # machine from `conditions`, since Kubernetes won't do it for us (for + # many good reasons). + # + # + # + # `conditions` can be of the following types - + # 1. (kubelet) Initialized (always True since we don't rely on init + # containers) + # 2. (kubelet) ContainersReady + # 3. (kubelet) Ready (same as ContainersReady since we don't use + # ReadinessGates - + # https://github.com/kubernetes/kubernetes/blob/master/pkg/kubelet/status/generate.go) + # 4. (kube-scheduler) PodScheduled + # (https://github.com/kubernetes/kubernetes/blob/master/pkg/scheduler/scheduler.go) + # 5. (kube-scheduler) Unschedulable + # + # WIP... + + JOB_ACTIVE = "job:active" + JOB_FAILED = "" + + def __init__(self, client_wrapper, name, uid, namespace): + self._client_wrapper = client_wrapper + self._name = name + self._id = uid + self._namespace = namespace + + self._job = self._fetch_job() + self._pod = self._fetch_pod() + + import atexit + + atexit.register(self.kill) + + def __repr__(self): + return "{}('{}/{}')".format( + self.__class__.__name__, self._namespace, self._name + ) + + @k8s_retry() + def _fetch_job(self): + client = self._client_wrapper.get() + try: + return ( + client.BatchV1Api() + .read_namespaced_job(name=self._name, namespace=self._namespace) + .to_dict() + ) + except client.rest.ApiException as e: + # TODO: Handle failures as well as the fact that a different + # process can delete the job. + raise e + + @k8s_retry() + def _fetch_pod(self): + """Fetch pod metadata. May return None if pod does not exist.""" + client = self._client_wrapper.get() + + pods = ( + client.CoreV1Api() + .list_namespaced_pod( + namespace=self._namespace, + label_selector="job-name={}".format(self._name), + ) + .to_dict()["items"] + ) + + if pods: + return pods[0] + else: + return None + + def kill(self): + # Terminating a Kubernetes job is a bit tricky. Issuing a + # `BatchV1Api.delete_namespaced_job` will also remove all traces of the + # job object from the Kubernetes API server which may not be desirable. + # This forces us to be a bit creative in terms of how we handle kill: + # + # 1. If the container is alive and kicking inside the pod, we simply + # attach ourselves to the container and issue a kill signal. The + # way we have initialized the Job ensures that the job will cleanly + # terminate. + # 2. In scenarios where either the pod (unschedulable etc.) or the + # container (ImagePullError etc.) hasn't come up yet, we become a + # bit creative by patching the job parallelism to 0. This ensures + # that the underlying node's resources are made available to + # kube-scheduler again. The downside is that the Job wouldn't mark + # itself as done and the pod metadata disappears from the API + # server. There is an open issue in the Kubernetes GH to provide + # better support for job terminations - + # https://github.com/kubernetes/enhancements/issues/2232 but + # meanwhile as a quick follow-up, we should investigate ways to + # terminate the pod without deleting the object. + # 3. If the pod object hasn't shown up yet, we set the parallelism to 0 + # to preempt it. + client = self._client_wrapper.get() + if not self._check_is_done(): + if self._check_is_running(): + + # Unless there is a bug in the code, self._pod cannot be None + # if we're in "running" state. + assert self._pod is not None + + # Case 1. + from kubernetes.stream import stream + + api_instance = client.CoreV1Api + try: + # TODO (savin): stream opens a web-socket connection. 
It may + # not be desirable to open multiple web-socket + # connections frivolously (think killing a + # workflow during a for-each step). + stream( + api_instance().connect_get_namespaced_pod_exec, + name=self._pod["metadata"]["name"], + namespace=self._namespace, + command=[ + "/bin/sh", + "-c", + "/sbin/killall5", + ], + stderr=True, + stdin=False, + stdout=True, + tty=False, + ) + except: + # Best effort. It's likely that this API call could be + # blocked for the user. + # TODO (savin): Forward the error to the user. + # pass + raise + else: + # Case 2. + try: + # TODO (savin): Also patch job annotation to reflect this + # action. + client.BatchV1Api().patch_namespaced_job( + name=self._name, + namespace=self._namespace, + field_manager="metaflow", + body={"spec": {"parallelism": 0}}, + ) + except: + # Best effort. + # TODO (savin): Forward the error to the user. + # pass + raise + return self + + @property + def id(self): + # TODO (savin): Should we use pod id instead? + return self._id + + def _check_is_done(self): + def _job_done(): + # Either the job succeeds or fails naturally or we may have + # forced the pod termination causing the job to still be in an + # active state but for all intents and purposes dead to us. + + # TODO (savin): check for self._job + return ( + bool(self._job["status"].get("succeeded")) + or bool(self._job["status"].get("failed")) + or (self._job["spec"]["parallelism"] == 0) + ) + + if not _job_done(): + # If not done, check for newer status + self._job = self._fetch_job() + if _job_done(): + return True + else: + # It is possible for the job metadata to not be updated yet, but the + # Pod has already succeeded or failed. + self._pod = self._fetch_pod() + if self._pod and (self._pod["status"]["phase"] in ("Succeeded", "Failed")): + return True + else: + return False + + def _get_status(self): + if not self._check_is_done(): + # If not done, check for newer pod status + self._pod = self._fetch_pod() + # Success! + if bool(self._job["status"].get("succeeded")): + return "Job:Succeeded" + # Failure! + if bool(self._job["status"].get("failed")) or ( + self._job["spec"]["parallelism"] == 0 + ): + return "Job:Failed" + if bool(self._job["status"].get("active")): + msg = "Job:Active" + if self._pod: + msg += " Pod:%s" % self._pod["status"]["phase"].title() + # TODO (savin): parse Pod conditions + container_status = ( + self._pod["status"].get("container_statuses") or [None] + )[0] + if container_status: + # We have a single container inside the pod + status = {"status": "waiting"} + for k, v in container_status["state"].items(): + if v is not None: + status["status"] = k + status.update(v) + msg += " Container:%s" % status["status"].title() + reason = "" + if status.get("reason"): + reason = status["reason"] + if status.get("message"): + reason += ":%s" % status["message"] + if reason: + msg += " [%s]" % reason + # TODO (savin): This message should be shortened before release. 
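            # With the pieces assembled above, the synthesized string looks
            # like (illustrative values only):
            #   "Job:Active Pod:Pending Container:Waiting [ErrImagePull:rpc error ...]"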
+ return msg + return "Job:Unknown" + + def _check_has_succeeded(self): + # Job is in a terminal state and the status is marked as succeeded + if self._check_is_done(): + if bool(self._job["status"].get("succeeded")) or ( + self._pod and self._pod["status"]["phase"] == "Succeeded" + ): + return True + else: + return False + else: + return False + + def _check_has_failed(self): + # Job is in a terminal state and either the status is marked as failed + # or the Job is not allowed to launch any more pods + + if self._check_is_done(): + if ( + bool(self._job["status"].get("failed")) + or (self._job["spec"]["parallelism"] == 0) + or (self._pod and self._pod["status"]["phase"] == "Failed") + ): + return True + else: + return False + else: + return False + + def _check_is_running(self): + # Returns true if the container is running. + if not self._check_is_done(): + # If not done, check if pod has been assigned and is in Running + # phase + if self._pod is None: + return False + pod_phase = self._pod.get("status", {}).get("phase") + return pod_phase == "Running" + return False + + def _get_done_reason(self): + if self._check_is_done(): + if self._check_has_succeeded(): + return 0, None + # Best effort since Pod object can disappear on us at anytime + else: + + def _done(): + return self._pod.get("status", {}).get("phase") in ( + "Succeeded", + "Failed", + ) + + if not _done(): + # If pod status is dirty, check for newer status + self._pod = self._fetch_pod() + if self._pod: + pod_status = self._pod["status"] + if pod_status.get("container_statuses") is None: + # We're done, but no container_statuses is set + # This can happen when the pod is evicted + return None, ": ".join( + filter( + None, + [pod_status.get("reason"), pod_status.get("message")], + ) + ) + + for k, v in ( + pod_status.get("container_statuses", [{}])[0] + .get("state", {}) + .items() + ): + if v is not None: + return v.get("exit_code"), ": ".join( + filter( + None, + [v.get("reason"), v.get("message")], + ) + ) + + return None, None + + @property + def is_done(self): + return self._check_is_done() + + @property + def has_failed(self): + return self._check_has_failed() + + @property + def is_running(self): + return self._check_is_running() + + @property + def reason(self): + return self._get_done_reason() + + @property + def status(self): + return self._get_status() diff --git a/metaflow/plugins/aws/eks/kubernetes_decorator.py b/metaflow/plugins/aws/eks/kubernetes_decorator.py new file mode 100644 index 00000000000..46e937d674c --- /dev/null +++ b/metaflow/plugins/aws/eks/kubernetes_decorator.py @@ -0,0 +1,291 @@ +import os +import sys +import platform +import requests + +from metaflow import util +from metaflow.decorators import StepDecorator +from metaflow.metadata import MetaDatum +from metaflow.metadata.util import sync_local_metadata_to_datastore +from metaflow.metaflow_config import ( + ECS_S3_ACCESS_IAM_ROLE, + BATCH_JOB_QUEUE, + BATCH_CONTAINER_IMAGE, + BATCH_CONTAINER_REGISTRY, + ECS_FARGATE_EXECUTION_ROLE, + DATASTORE_LOCAL_DIR, +) +from metaflow.plugins import ResourcesDecorator +from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task +from metaflow.sidecar import SidecarSubProcess + +from .kubernetes import KubernetesException +from ..aws_utils import get_docker_registry + + +class KubernetesDecorator(StepDecorator): + """ + TODO (savin): Update this docstring. + Step decorator to specify that this step should execute on Kubernetes. 
+ + This decorator indicates that your step should execute on Kubernetes. Note + that you can apply this decorator automatically to all steps using the + ```--with kubernetes``` argument when calling run/resume. Step level + decorators within the code are overrides and will force a step to execute + on Kubernetes regardless of the ```--with``` specification. + + To use, annotate your step as follows: + ``` + @kubernetes + @step + def my_step(self): + ... + ``` + Parameters + ---------- + cpu : int + Number of CPUs required for this step. Defaults to 1. If @resources is + also present, the maximum value from all decorators is used + gpu : int + Number of GPUs required for this step. Defaults to 0. If @resources is + also present, the maximum value from all decorators is used + memory : int + Memory size (in MB) required for this step. Defaults to 4096. If + @resources is also present, the maximum value from all decorators is + used + image : string + Docker image to use when launching on Kubernetes. If not specified, a + default docker image mapping to the current version of Python is used + shared_memory : int + The value for the size (in MiB) of the /dev/shm volume for this step. + This parameter maps to the --shm-size option to docker run. + """ + + name = "kubernetes" + defaults = { + "cpu": "1", + "memory": "4096", + "disk": "10240", + "image": None, + "service_account": None, + "secrets": None, # e.g., mysecret + "node_selector": None, # e.g., kubernetes.io/os=linux + "gpu": "0", + # "shared_memory": None, + "namespace": None, + } + package_url = None + package_sha = None + run_time_limit = None + + def __init__(self, attributes=None, statically_defined=False): + super(KubernetesDecorator, self).__init__( + attributes, statically_defined + ) + + # TODO: Unify the logic with AWS Batch + # If no docker image is explicitly specified, impute a default image. + if not self.attributes["image"]: + # If metaflow-config specifies a docker image, just use that. + if BATCH_CONTAINER_IMAGE: + self.attributes["image"] = BATCH_CONTAINER_IMAGE + # If metaflow-config doesn't specify a docker image, assign a + # default docker image. + else: + # Default to vanilla Python image corresponding to major.minor + # version of the Python interpreter launching the flow. + self.attributes["image"] = "python:%s.%s" % ( + platform.python_version_tuple()[0], + platform.python_version_tuple()[1], + ) + # Assign docker registry URL for the image. + if not get_docker_registry(self.attributes["image"]): + if BATCH_CONTAINER_REGISTRY: + self.attributes["image"] = "%s/%s" % ( + BATCH_CONTAINER_REGISTRY.rstrip("/"), + self.attributes["image"], + ) + + # Refer https://github.com/Netflix/metaflow/blob/master/docs/lifecycle.png + # to understand where these functions are invoked in the lifecycle of a + # Metaflow flow. + def step_init( + self, flow, graph, step, decos, environment, flow_datastore, logger + ): + # Executing Kubernetes jobs requires a non-local datastore at the + # moment. + # TODO: To support MiniKube we need to enable local datastore execution. + if flow_datastore.TYPE != "s3": + raise KubernetesException( + "The *@kubernetes* decorator requires --datastore=s3 " + "at the moment." + ) + + # Set internal state. 
+ self.logger = logger + self.environment = environment + self.step = step + self.flow_datastore = flow_datastore + for deco in decos: + if isinstance(deco, ResourcesDecorator): + for k, v in deco.attributes.items(): + # We use the larger of @resources and @k8s attributes + # TODO: Fix https://github.com/Netflix/metaflow/issues/467 + my_val = self.attributes.get(k) + if not (my_val is None and v is None): + self.attributes[k] = str( + max(int(my_val or 0), int(v or 0)) + ) + + # Set run time limit for the Kubernetes job. + self.run_time_limit = get_run_time_limit_for_task(decos) + if self.run_time_limit < 60: + raise KubernetesException( + "The timeout for step *{step}* should be " + "at least 60 seconds for execution on " + "Kubernetes.".format(step=step) + ) + + def runtime_init(self, flow, graph, package, run_id): + # Set some more internal state. + self.flow = flow + self.graph = graph + self.package = package + self.run_id = run_id + + def runtime_task_created(self, + task_datastore, + task_id, + split_index, + input_paths, + is_cloned, + ubf_context): + # To execute the Kubernetes job, the job container needs to have + # access to the code package. We store the package in the datastore + # which the pod is able to download as part of it's entrypoint. + if not is_cloned: + self._save_package_once(self.flow_datastore, self.package) + + def runtime_step_cli( + self, cli_args, retry_count, max_user_code_retries, ubf_context + ): + if retry_count <= max_user_code_retries: + # After all attempts to run the user code have failed, we don't need + # to execute on Kubernetes anymore. We can execute possible fallback + # code locally. + cli_args.commands = ["kubernetes", "step"] + cli_args.command_args.append(self.package_sha) + cli_args.command_args.append(self.package_url) + + # --namespace is used to specify Metaflow namespace (different + # concept from k8s namespace). + for k,v in self.attributes.items(): + if k == 'namespace': + cli_args.command_options['k8s_namespace'] = v + else: + cli_args.command_options[k] = v + cli_args.command_options["run-time-limit"] = self.run_time_limit + cli_args.entrypoint[0] = sys.executable + + def task_pre_step(self, + step_name, + task_datastore, + metadata, + run_id, + task_id, + flow, + graph, + retry_count, + max_retries, + ubf_context, + inputs): + self.metadata = metadata + self.task_datastore = task_datastore + + # task_pre_step may run locally if fallback is activated for @catch + # decorator. In that scenario, we skip collecting Kubernetes execution + # metadata. A rudimentary way to detect non-local execution is to + # check for the existence of METAFLOW_KUBERNETES_WORKLOAD environment + # variable. + + if "METAFLOW_KUBERNETES_WORKLOAD" in os.environ: + meta = {} + # TODO: Get kubernetes job id and job name + meta["kubernetes-pod-id"] = os.environ["METAFLOW_KUBERNETES_POD_ID"] + meta["kubernetes-pod-name"] = os.environ[ + "METAFLOW_KUBERNETES_POD_NAME" + ] + meta["kubernetes-pod-namespace"] = os.environ[ + "METAFLOW_KUBERNETES_POD_NAMESPACE" + ] + # meta['kubernetes-job-attempt'] = ? + + entries = [ + MetaDatum(field=k, value=v, type=k, tags=[]) + for k, v in meta.items() + ] + # Register book-keeping metadata for debugging. + metadata.register_metadata(run_id, step_name, task_id, entries) + + # Start MFLog sidecar to collect task logs. 
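        # ("save_logs_periodically" runs as a separate process and, as the name
        # suggests, periodically persists whatever stdout/stderr mflog has
        # captured so far, so partial logs stay visible while the task runs.)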
+ self._save_logs_sidecar = SidecarSubProcess( + "save_logs_periodically" + ) + + def task_post_step(self, + step_name, + flow, + graph, + retry_count, + max_user_code_retries): + # task_post_step may run locally if fallback is activated for @catch + # decorator. + if 'METAFLOW_KUBERNETES_WORKLOAD' in os.environ: + # If `local` metadata is configured, we would need to copy task + # execution metadata from the AWS Batch container to user's + # local file system after the user code has finished execution. + # This happens via datastore as a communication bridge. + if self.metadata.TYPE == 'local': + # Note that the datastore is *always* Amazon S3 (see + # runtime_task_created function). + sync_local_metadata_to_datastore(DATASTORE_LOCAL_DIR, + self.task_datastore) + + def task_exception(self, + exception, + step_name, + flow, + graph, + retry_count, + max_user_code_retries): + # task_exception may run locally if fallback is activated for @catch + # decorator. + if 'METAFLOW_KUBERNETES_WORKLOAD' in os.environ: + # If `local` metadata is configured, we would need to copy task + # execution metadata from the AWS Batch container to user's + # local file system after the user code has finished execution. + # This happens via datastore as a communication bridge. + if self.metadata.TYPE == 'local': + # Note that the datastore is *always* Amazon S3 (see + # runtime_task_created function). + sync_local_metadata_to_datastore(DATASTORE_LOCAL_DIR, + self.task_datastore) + + def task_finished(self, + step_name, + flow, + graph, + is_task_ok, + retry_count, + max_retries): + try: + self._save_logs_sidecar.kill() + except: + # Best effort kill + pass + + @classmethod + def _save_package_once(cls, flow_datastore, package): + if cls.package_url is None: + cls.package_url, cls.package_sha = flow_datastore.save_data( + [package.blob], len_hint=1)[0] \ No newline at end of file diff --git a/metaflow/plugins/conda/conda_step_decorator.py b/metaflow/plugins/conda/conda_step_decorator.py index cbbbeec6bbb..10a9b9894f2 100644 --- a/metaflow/plugins/conda/conda_step_decorator.py +++ b/metaflow/plugins/conda/conda_step_decorator.py @@ -193,13 +193,13 @@ def _disable_safety_checks(self, decos): # a macOS. This is needed because of gotchas around inconsistently # case-(in)sensitive filesystems for macOS and linux. 
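        # @kubernetes, like @batch, runs the step remotely on linux, so the
        # same relaxation applies to both decorators in the loop below.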
for deco in decos: - if deco.name == 'batch' and platform.system() == 'Darwin': + if deco.name in ('batch', 'kubernetes') and platform.system() == 'Darwin': return True return False def _architecture(self, decos): for deco in decos: - if deco.name == 'batch': + if deco.name in ('batch', 'kubernetes'): # force conda resolution for linux-64 architectures return 'linux-64' bit = '32' @@ -306,7 +306,9 @@ def runtime_step_cli(self, retry_count, max_user_code_retries, ubf_context): - if self.is_enabled(ubf_context) and 'batch' not in cli_args.commands: + no_batch = 'batch' not in cli_args.commands + no_kubernetes = 'kubernetes' not in cli_args.commands + if self.is_enabled(ubf_context) and no_batch and no_kubernetes: python_path = self.metaflow_home if self.addl_paths is not None: addl_paths = os.pathsep.join(self.addl_paths) diff --git a/metaflow/task.py b/metaflow/task.py index b49bf77b0aa..b98a2c8c0ee 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -496,10 +496,6 @@ def run_step(self, output.save_metadata({'task_end': {}}) output.persist(self.flow) - # this writes a success marker indicating that the - # "transaction" is done - output.done() - # final decorator hook: The task results are now # queryable through the client API / datastore for deco in decorators: @@ -510,6 +506,10 @@ def run_step(self, retry_count, max_user_code_retries) + # this writes a success marker indicating that the + # "transaction" is done + output.done() + # terminate side cars logger.terminate() self.metadata.stop_heartbeat() diff --git a/test/core/contexts.json b/test/core/contexts.json index 6d0e79b238c..c4bf3d98cee 100644 --- a/test/core/contexts.json +++ b/test/core/contexts.json @@ -125,6 +125,40 @@ "DetectSegFaultTest", "TimeoutDecoratorTest" ] + }, + { + "name": "python3-k8s", + "disabled": true, + "python": "python3", + "top_options": [ + "--event-logger=nullSidecarLogger", + "--no-pylint", + "--quiet", + "--with=kubernetes:memory=256,disk=1024", + "--datastore=s3" + ], + "env": { + "METAFLOW_USER": "tester", + "METAFLOW_RUN_BOOL_PARAM": "False", + "METAFLOW_RUN_NO_DEFAULT_PARAM": "test_str", + "METAFLOW_DEFAULT_METADATA": "service" + }, + "run_options": [ + "--max-workers", "50", + "--max-num-splits", "10000", + "--tag", "\u523a\u8eab means sashimi", + "--tag", "multiple tags should be ok" + ], + "checks": ["python3-cli", "python3-metadata"], + "disabled_tests": [ + "LargeArtifactTest", + "WideForeachTest", + "TagCatchTest", + "BasicUnboundedForeachTest", + "NestedUnboundedForeachTest", + "DetectSegFaultTest", + "TimeoutDecoratorTest" + ] } ], "checks": { diff --git a/test/unit/test_k8s_job_name_sanitizer.py b/test/unit/test_k8s_job_name_sanitizer.py new file mode 100644 index 00000000000..e019a47bea3 --- /dev/null +++ b/test/unit/test_k8s_job_name_sanitizer.py @@ -0,0 +1,26 @@ +import re +from metaflow.plugins.aws.eks.kubernetes import generate_rfc1123_name + +rfc1123 = re.compile(r'^[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?$') + +def test_job_name_santitizer(): + # Basic name + assert rfc1123.match(generate_rfc1123_name('HelloFlow', '1', 'end', '321', '1')) + + # Step name ends with _ + assert rfc1123.match(generate_rfc1123_name('HelloFlow', '1', '_end', '321', '1')) + + # Step name starts and ends with _ + assert rfc1123.match(generate_rfc1123_name('HelloFlow', '1', '_end_', '321', '1')) + + # Flow name ends with _ + assert rfc1123.match(generate_rfc1123_name('HelloFlow_', '1', 'end', '321', '1')) + + # Same flow name, different case must produce different job names + assert 
generate_rfc1123_name('Helloflow', '1', 'end', '321', '1') != generate_rfc1123_name('HelloFlow', '1', 'end', '321', '1') + + # Very long step name should be fine + assert rfc1123.match(generate_rfc1123_name('Helloflow', '1', 'end'*50, '321', '1')) + + # Very long run id should be fine too + assert rfc1123.match(generate_rfc1123_name('Helloflow', '1'*100, 'end', '321', '1')) \ No newline at end of file diff --git a/test/unit/test_k8s_label_sanitizer.py b/test/unit/test_k8s_label_sanitizer.py new file mode 100644 index 00000000000..6fcfbd5553f --- /dev/null +++ b/test/unit/test_k8s_label_sanitizer.py @@ -0,0 +1,28 @@ +import re +from metaflow.plugins.aws.eks.kubernetes import sanitize_label_value, LABEL_VALUE_REGEX + + +def test_label_value_santitizer(): + assert LABEL_VALUE_REGEX.match(sanitize_label_value('HelloFlow')) + + # The value is too long + assert LABEL_VALUE_REGEX.match(sanitize_label_value('a' * 1000)) + + # Different long values should still not be equal after sanitization + assert sanitize_label_value('a' * 1000) != sanitize_label_value('a' * 1001) + assert sanitize_label_value('-' * 1000) != sanitize_label_value('-' * 1001) + + # Different long values should still not be equal after sanitization + assert sanitize_label_value('alice!') != sanitize_label_value('alice?') + + # ends with dash + assert LABEL_VALUE_REGEX.match(sanitize_label_value('HelloFlow-')) + + # non-ascii + assert LABEL_VALUE_REGEX.match(sanitize_label_value('метафлоу')) + + # different only in case + assert sanitize_label_value('Alice') != sanitize_label_value('alice') + + # spaces + assert LABEL_VALUE_REGEX.match(sanitize_label_value('Meta flow')) \ No newline at end of file
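The implementations of generate_rfc1123_name and sanitize_label_value live in metaflow/plugins/aws/eks/kubernetes.py and are not shown in this hunk. Purely as an illustration of the properties the tests above assert (RFC 1123 compliance, the 63-character limit, and distinct outputs for inputs that differ only in case or only beyond the truncation point), a minimal name sanitizer could look like the sketch below; sanitize_job_name is a hypothetical stand-in, not the helper the plugin actually uses.

import hashlib
import re

def sanitize_job_name(flow_name, run_id, step_name, task_id, attempt):
    # Hash the raw, case-sensitive inputs so that names differing only in case
    # (or only beyond the truncation point) still map to distinct job names.
    raw = "-".join((flow_name, run_id, step_name, task_id, attempt))
    digest = hashlib.sha256(raw.encode("utf-8")).hexdigest()[:8]
    # Keep only RFC 1123 friendly characters for the readable part of the name.
    base = re.sub(r"[^a-z0-9-]", "", raw.lower()).strip("-")
    # Leave room for the digest and the joining dash within the 63-char limit.
    return "%s-%s" % (base[: 63 - len(digest) - 1].rstrip("-"), digest)

sanitize_job_name("HelloFlow", "1", "end", "321", "1")
# e.g. 'helloflow-1-end-321-1-' followed by eight hex characters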