diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi index 35a56e231e6bd..9238b5f6daa7c 100644 --- a/airflow/decorators/__init__.pyi +++ b/airflow/decorators/__init__.pyi @@ -22,11 +22,14 @@ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union, overload +from kubernetes.client import models as k8s + from airflow.decorators.base import FParams, FReturn, Task, TaskDecorator from airflow.decorators.branch_python import branch_task from airflow.decorators.python import python_task from airflow.decorators.python_virtualenv import virtualenv_task from airflow.decorators.task_group import task_group +from airflow.kubernetes.secret import Secret from airflow.models.dag import dag # Please keep this in sync with __init__.py's __all__. @@ -239,5 +242,120 @@ class TaskDecoratorCollection: :param cap_add: Include container capabilities """ # [END decorator_signature] + def kubernetes( + self, + *, + image: str, + kubernetes_conn_id: str = ..., + namespace: str = "default", + name: str = ..., + random_name_suffix: bool = True, + ports: Optional[List[k8s.V1ContainerPort]] = None, + volume_mounts: Optional[List[k8s.V1VolumeMount]] = None, + volumes: Optional[List[k8s.V1Volume]] = None, + env_vars: Optional[List[k8s.V1EnvVar]] = None, + env_from: Optional[List[k8s.V1EnvFromSource]] = None, + secrets: Optional[List[Secret]] = None, + in_cluster: Optional[bool] = None, + cluster_context: Optional[str] = None, + labels: Optional[Dict] = None, + reattach_on_restart: bool = True, + startup_timeout_seconds: int = 120, + get_logs: bool = True, + image_pull_policy: Optional[str] = None, + annotations: Optional[Dict] = None, + container_resources: Optional[k8s.V1ResourceRequirements] = None, + affinity: Optional[k8s.V1Affinity] = None, + config_file: str = ..., + node_selector: Optional[dict] = None, + image_pull_secrets: Optional[List[k8s.V1LocalObjectReference]] = None, + service_account_name: Optional[str] = None, + 
is_delete_operator_pod: bool = True, + hostnetwork: bool = False, + tolerations: Optional[List[k8s.V1Toleration]] = None, + security_context: Optional[Dict] = None, + dnspolicy: Optional[str] = None, + schedulername: Optional[str] = None, + init_containers: Optional[List[k8s.V1Container]] = None, + log_events_on_failure: bool = False, + do_xcom_push: bool = False, + pod_template_file: Optional[str] = None, + priority_class_name: Optional[str] = None, + pod_runtime_info_envs: Optional[List[k8s.V1EnvVar]] = None, + termination_grace_period: Optional[int] = None, + configmaps: Optional[List[str]] = None, + **kwargs, + ) -> TaskDecorator: + """Create a decorator to convert a callable to a Kubernetes Pod task. + + :param kubernetes_conn_id: The Kubernetes cluster's + :ref:`connection ID `. + :param namespace: Namespace to run within Kubernetes. Defaults to *default*. + :param image: Docker image to launch. Defaults to *hub.docker.com*, but + a fully qualified URL will point to a custom repository. (templated) + :param name: Name of the pod to run. This will be used (plus a random + suffix if *random_name_suffix* is *True*) to generate a pod ID + (DNS-1123 subdomain, containing only ``[a-z0-9.-]``). Defaults to + ``k8s_airflow_pod_{RANDOM_UUID}``. + :param random_name_suffix: If *True*, will generate a random suffix. + :param ports: Ports for the launched pod. + :param volume_mounts: *volumeMounts* for the launched pod. + :param volumes: Volumes for the launched pod. Includes *ConfigMaps* and + *PersistentVolumes*. + :param env_vars: Environment variables initialized in the container. + (templated) + :param env_from: List of sources to populate environment variables in + the container. + :param secrets: Kubernetes secrets to inject in the container. They can + be exposed as environment variables or files in a volume. + :param in_cluster: Run kubernetes client with *in_cluster* configuration. + :param cluster_context: Context that points to the Kubernetes cluster. 
+ Ignored when *in_cluster* is *True*. If *None*, current-context will + be used. + :param reattach_on_restart: If the worker dies while the pod is running, + reattach and monitor during the next try. If *False*, always create + a new pod for each try. + :param labels: Labels to apply to the pod. (templated) + :param startup_timeout_seconds: Timeout in seconds to startup the pod. + :param get_logs: Get the stdout of the container as logs of the tasks. + :param image_pull_policy: Specify a policy to cache or always pull an + image. + :param annotations: Non-identifying metadata you can attach to the pod. + Can be a large range of data, and can include characters that are + not permitted by labels. + :param container_resources: Resources for the launched pod. + :param affinity: Affinity scheduling rules for the launched pod. + :param config_file: The path to the Kubernetes config file. If not + specified, default value is ``~/.kube/config``. (templated) + :param node_selector: A dict containing a group of scheduling rules. + :param image_pull_secrets: Any image pull secrets to be given to the + pod. If more than one secret is required, provide a comma separated + list, e.g. ``secret_a,secret_b``. + :param service_account_name: Name of the service account. + :param is_delete_operator_pod: What to do when the pod reaches its final + state, or the execution is interrupted. If *True* (default), delete + the pod; otherwise leave the pod. + :param hostnetwork: If *True*, enable host networking on the pod. + :param tolerations: A list of Kubernetes tolerations. + :param security_context: Security options the pod should run with + (PodSecurityContext). + :param dnspolicy: DNS policy for the pod. + :param schedulername: Specify a scheduler name for the pod + :param init_containers: Init containers for the launched pod. + :param log_events_on_failure: Log the pod's events if a failure occurs. 
+ :param do_xcom_push: If *True*, the content of + ``/airflow/xcom/return.json`` in the container will also be pushed + to an XCom when the container completes. + :param pod_template_file: Path to pod template file (templated) + :param priority_class_name: Priority class name for the launched pod. + :param pod_runtime_info_envs: A list of environment variables + to be set in the container. + :param termination_grace_period: Termination grace period if task killed + in UI, defaults to kubernetes default + :param configmaps: A list of names of config maps from which it collects + ConfigMaps to populate the environment variables with. The contents + of the target ConfigMap's Data field will represent the key-value + pairs as environment variables. Extends env_from. + """ task: TaskDecoratorCollection diff --git a/airflow/providers/cncf/kubernetes/decorators/__init__.py b/airflow/providers/cncf/kubernetes/decorators/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/airflow/providers/cncf/kubernetes/decorators/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/airflow/providers/cncf/kubernetes/decorators/kubernetes.py b/airflow/providers/cncf/kubernetes/decorators/kubernetes.py new file mode 100644 index 0000000000000..9acab6d093504 --- /dev/null +++ b/airflow/providers/cncf/kubernetes/decorators/kubernetes.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
import inspect
import os
import pickle
import uuid
from tempfile import TemporaryDirectory
from textwrap import dedent
from typing import TYPE_CHECKING, Callable, Optional, Sequence

from kubernetes.client import models as k8s

from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
from airflow.providers.cncf.kubernetes.python_kubernetes_script import (
    remove_task_decorator,
    write_python_script,
)

if TYPE_CHECKING:
    from airflow.utils.context import Context

# Name of the environment variable that carries the rendered script into the container.
_PYTHON_SCRIPT_ENV = "__PYTHON_SCRIPT"

# Path inside the container where the script is written before being run.
_FILENAME_IN_CONTAINER = "/tmp/script.py"


def _generate_decode_command() -> str:
    """Return the shell command that copies the script from the environment to a file.

    The command runs ``python -c`` inside the container, reads the
    ``_PYTHON_SCRIPT_ENV`` environment variable and writes its value to
    ``_FILENAME_IN_CONTAINER``.
    """
    return (
        f'python -c "import base64, os;'
        rf'x = os.environ[\"{_PYTHON_SCRIPT_ENV}\"];'
        rf'f = open(\"{_FILENAME_IN_CONTAINER}\", \"w\"); f.write(x); f.close()"'
    )


def _read_file_contents(filename):
    """Return the full text content of *filename*."""
    with open(filename) as script_file:
        return script_file.read()


class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
    """Operator behind ``@task.kubernetes``: runs the decorated callable's source in a pod.

    The callable's source is rendered into a standalone script, handed to the
    container through an environment variable, written to a file by the
    container command and executed with ``python``.
    """

    custom_operator_name = "@task.kubernetes"

    template_fields: Sequence[str] = ('op_args', 'op_kwargs')

    # since we won't mutate the arguments, we should just do the shallow copy
    # there are some cases we can't deepcopy the objects (e.g protobuf).
    shallow_copy_attrs: Sequence[str] = ('python_callable',)

    def __init__(self, namespace: str = "default", **kwargs) -> None:
        """Configure the pod to decode the script from the environment and run it.

        :param namespace: Namespace passed through to the pod operator.
        :param kwargs: Remaining operator arguments; ``name`` defaults to
            ``k8s_airflow_pod_<random hex>`` when not supplied.
        """
        self.pickling_library = pickle
        super().__init__(
            namespace=namespace,
            name=kwargs.pop("name", f"k8s_airflow_pod_{uuid.uuid4().hex}"),
            cmds=["bash"],
            arguments=["-cx", f"{_generate_decode_command()} && python {_FILENAME_IN_CONTAINER}"],
            **kwargs,
        )

    def _get_python_source(self):
        """Return the dedented source of the callable with the decorator line removed."""
        raw_source = inspect.getsource(self.python_callable)
        res = dedent(raw_source)
        res = remove_task_decorator(res, "@task.kubernetes")
        return res

    def execute(self, context: 'Context'):
        """Render the script, attach it as an environment variable and run the pod.

        :param context: Task context forwarded to the parent ``execute``.
        :return: Whatever the parent ``execute`` returns.
        """
        with TemporaryDirectory(prefix="venv") as tmp_dir:
            script_filename = os.path.join(tmp_dir, 'script.py')
            py_source = self._get_python_source()

            # NOTE(review): when op_args/op_kwargs are non-empty the Jinja template
            # loads them from ``sys.argv[1]``, but the container command above passes
            # no argument to the script -- confirm how arguments are meant to reach
            # the pod.
            jinja_context = {
                "op_args": self.op_args,
                "op_kwargs": self.op_kwargs,
                "pickling_library": self.pickling_library.__name__,
                "python_callable": self.python_callable.__name__,
                "python_callable_source": py_source,
                "string_args_global": False,
            }
            write_python_script(jinja_context=jinja_context, filename=script_filename)

            script_env = k8s.V1EnvVar(name=_PYTHON_SCRIPT_ENV, value=_read_file_contents(script_filename))
            # Drop any script variable left by a previous execute() on this instance so
            # repeated runs do not accumulate duplicate entries.
            self.env_vars = [
                *(var for var in self.env_vars if getattr(var, "name", None) != _PYTHON_SCRIPT_ENV),
                script_env,
            ]
            return super().execute(context)


def kubernetes_task(
    python_callable: Optional[Callable] = None,
    multiple_outputs: Optional[bool] = None,
    **kwargs,
) -> TaskDecorator:
    """Kubernetes operator decorator.

    This wraps a function to be executed in K8s using KubernetesPodOperator.
    Also accepts any argument that KubernetesPodOperator will via ``kwargs``.
    Can be reused in a single DAG.

    :param python_callable: Function to decorate
    :param multiple_outputs: if set, function return value will be
        unrolled to multiple XCom values. Dict will unroll to xcom values with
        keys as XCom keys. Defaults to False.
    """
    return task_decorator_factory(
        python_callable=python_callable,
        multiple_outputs=multiple_outputs,
        decorated_operator_class=_KubernetesDecoratedOperator,
        **kwargs,
    )
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +-#} + +import {{ pickling_library }} +import sys + +{# Check whether Airflow is available in the environment. + # If it is, we'll want to ensure that we integrate any macros that are being provided + # by plugins prior to unpickling the task context. #} +if sys.version_info >= (3,6): + try: + from airflow.plugins_manager import integrate_macros_plugins + integrate_macros_plugins() + except ImportError: + {# Airflow is not available in this environment, therefore we won't + # be able to integrate any plugin macros. #} + pass + +{% if op_args or op_kwargs %} +with open(sys.argv[1], "rb") as file: + arg_dict = {{ pickling_library }}.load(file) +{% else %} +arg_dict = {"args": [], "kwargs": {}} +{% endif %} + +# Script +{{ python_callable_source }} +res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"]) diff --git a/airflow/providers/cncf/kubernetes/python_kubernetes_script.py b/airflow/providers/cncf/kubernetes/python_kubernetes_script.py new file mode 100644 index 0000000000000..a13eec6ff372f --- /dev/null +++ b/airflow/providers/cncf/kubernetes/python_kubernetes_script.py @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""Utilities for using the kubernetes decorator""" +import os +from collections import deque + +import jinja2 + + +def _balance_parens(after_decorator): + num_paren = 1 + after_decorator = deque(after_decorator) + after_decorator.popleft() + while num_paren: + current = after_decorator.popleft() + if current == "(": + num_paren = num_paren + 1 + elif current == ")": + num_paren = num_paren - 1 + return ''.join(after_decorator) + + +def remove_task_decorator(python_source: str, task_decorator_name: str) -> str: + """ + Removed @kubernetes_task + + :param python_source: + """ + if task_decorator_name not in python_source: + return python_source + split = python_source.split(task_decorator_name) + before_decorator, after_decorator = split[0], split[1] + if after_decorator[0] == "(": + after_decorator = _balance_parens(after_decorator) + if after_decorator[0] == "\n": + after_decorator = after_decorator[1:] + return before_decorator + after_decorator + + +def write_python_script( + jinja_context: dict, + filename: str, + render_template_as_native_obj: bool = False, +): + """ + Renders the python script to a file to execute in the virtual environment. + + :param jinja_context: The jinja context variables to unpack and replace with its placeholders in the + template file. + :param filename: The name of the file to dump the rendered script to. 
+ :param render_template_as_native_obj: If ``True``, rendered Jinja template would be converted + to a native Python object + """ + template_loader = jinja2.FileSystemLoader(searchpath=os.path.dirname(__file__)) + template_env: jinja2.Environment + if render_template_as_native_obj: + template_env = jinja2.nativetypes.NativeEnvironment( + loader=template_loader, undefined=jinja2.StrictUndefined + ) + else: + template_env = jinja2.Environment(loader=template_loader, undefined=jinja2.StrictUndefined) + template = template_env.get_template('python_kubernetes_script.jinja2') + template.stream(**jinja_context).dump(filename) diff --git a/tests/providers/cncf/kubernetes/decorators/__init__.py b/tests/providers/cncf/kubernetes/decorators/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/cncf/kubernetes/decorators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py new file mode 100644 index 0000000000000..3df6bf030ea63 --- /dev/null +++ b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
from typing import Iterator
from unittest import mock

import pytest

from airflow.decorators import task
from airflow.utils import timezone

DEFAULT_DATE = timezone.datetime(2021, 9, 1)

KPO_MODULE = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod"
POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesHook"


# Each fixture patches inside a ``with`` block and yields the mock, so the patch
# is undone when the test finishes instead of leaking into later tests.


@pytest.fixture(autouse=True)
def mock_create_pod() -> Iterator[mock.Mock]:
    """Patch ``PodManager.create_pod`` for the duration of a test."""
    with mock.patch(f"{POD_MANAGER_CLASS}.create_pod") as patched:
        yield patched


@pytest.fixture(autouse=True)
def mock_await_pod_start() -> Iterator[mock.Mock]:
    """Patch ``PodManager.await_pod_start`` for the duration of a test."""
    with mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start") as patched:
        yield patched


@pytest.fixture(autouse=True)
def mock_await_pod_completion() -> Iterator[mock.Mock]:
    """Patch ``PodManager.await_pod_completion`` to report a succeeded pod."""
    with mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion") as patched:
        patched.return_value = mock.MagicMock(**{"status.phase": "Succeeded"})
        yield patched


@pytest.fixture(autouse=True)
def mock_hook() -> Iterator[mock.Mock]:
    """Patch the ``KubernetesHook`` name used by the pod operator module."""
    with mock.patch(HOOK_CLASS) as patched:
        yield patched


def test_basic_kubernetes(dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock) -> None:
    """A decorated function produces one pod that decodes and runs the script."""
    with dag_maker(session=session) as dag:

        @task.kubernetes(
            image="python:3.10-slim-buster",
            in_cluster=False,
            cluster_context="default",
            config_file="/tmp/fake_file",
        )
        def f():
            import random

            return [random.random() for _ in range(100)]

        f()

    dr = dag_maker.create_dagrun()
    (ti,) = dr.task_instances
    dag.get_task("f").execute(context=ti.get_template_context(session=session))
    mock_hook.assert_called_once_with(
        conn_id=None,
        in_cluster=False,
        cluster_context="default",
        config_file="/tmp/fake_file",
    )
    assert mock_create_pod.call_count == 1

    containers = mock_create_pod.call_args[1]["pod"].spec.containers
    assert len(containers) == 1
    assert containers[0].command == ["bash"]

    assert len(containers[0].args) == 2
    assert containers[0].args[0] == "-cx"
    assert containers[0].args[1].endswith("/tmp/script.py")

    # The rendered script travels as the last environment variable.
    assert containers[0].env[-1].name == "__PYTHON_SCRIPT"
    assert containers[0].env[-1].value
from datetime import datetime

from airflow import DAG
from airflow.decorators import task

# Example DAG: two functions decorated with @task.kubernetes, run one after the other.
with DAG(
    dag_id="example_kubernetes_decorator",
    schedule=None,
    start_date=datetime(2021, 1, 1),
    tags=["example", "cncf", "kubernetes"],
    catchup=False,
) as dag:

    @task.kubernetes(
        image="python:3.8-slim-buster",
        name="k8s_test",
        namespace="default",
        in_cluster=False,
        config_file="/path/to/.kube/config",
    )
    def execute_in_k8s_pod():
        """Print a greeting, then wait two seconds."""
        import time

        print("Hello from k8s pod")
        time.sleep(2)

    @task.kubernetes(image="python:3.8-slim-buster", namespace="default", in_cluster=False)
    def print_pattern():
        """Print a five-row triangle of stars."""
        n = 5
        for i in range(0, n):
            # row i holds i + 1 stars, written without a trailing newline
            print("* " * (i + 1), end="")

            # finish the row
            print("\r")

    # The pod greeting task runs first, the pattern task second.
    execute_in_k8s_pod() >> print_pattern()


from tests.system.utils import get_test_run

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)