From b467e61bdf25f31213d58265f6f2aa4de67dc755 Mon Sep 17 00:00:00 2001 From: James Coder Date: Sat, 6 May 2023 11:58:07 -0400 Subject: [PATCH] fix kubernetes task decorator pickle error squash --- .../cncf/kubernetes/decorators/kubernetes.py | 7 +++--- .../kubernetes/decorators/test_kubernetes.py | 24 +++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/decorators/kubernetes.py b/airflow/providers/cncf/kubernetes/decorators/kubernetes.py index af3416b30fe69..337a54797db4f 100644 --- a/airflow/providers/cncf/kubernetes/decorators/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/decorators/kubernetes.py @@ -69,7 +69,7 @@ class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator): shallow_copy_attrs: Sequence[str] = ("python_callable",) def __init__(self, namespace: str = "default", use_dill: bool = False, **kwargs) -> None: - self.pickling_library = dill if use_dill else pickle + self.use_dill = use_dill super().__init__( namespace=namespace, name=kwargs.pop("name", f"k8s_airflow_pod_{uuid.uuid4().hex}"), @@ -112,17 +112,18 @@ def _generate_cmds(self) -> list[str]: def execute(self, context: Context): with TemporaryDirectory(prefix="venv") as tmp_dir: + pickling_library = dill if self.use_dill else pickle script_filename = os.path.join(tmp_dir, "script.py") input_filename = os.path.join(tmp_dir, "script.in") with open(input_filename, "wb") as file: - self.pickling_library.dump({"args": self.op_args, "kwargs": self.op_kwargs}, file) + pickling_library.dump({"args": self.op_args, "kwargs": self.op_kwargs}, file) py_source = self.get_python_source() jinja_context = { "op_args": self.op_args, "op_kwargs": self.op_kwargs, - "pickling_library": self.pickling_library.__name__, + "pickling_library": pickling_library.__name__, "python_callable": self.python_callable.__name__, "python_callable_source": py_source, "string_args_global": False, diff --git a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py index 8493ba067fb47..0cc9a72c12afb 100644 --- a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py +++ b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py @@ -208,3 +208,27 @@ def f(): assert len(dag.task_group.children) == 1 teardown_task = dag.task_group.children["f"] assert teardown_task._is_teardown + + +def test_kubernetes_with_mini_scheduler( + dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock +) -> None: + with dag_maker(session=session): + + @task.kubernetes( + image="python:3.10-slim-buster", + in_cluster=False, + cluster_context="default", + config_file="/tmp/fake_file", + ) + def f(arg1, arg2, kwarg1=None, kwarg2=None): + return {"key1": "value1", "key2": "value2"} + + f1 = f.override(task_id="my_task_id", do_xcom_push=True)("arg1", "arg2", kwarg1="kwarg1") + f.override(task_id="my_task_id2", do_xcom_push=False)("arg1", "arg2", kwarg1=f1) + + dr = dag_maker.create_dagrun() + (ti, _) = dr.task_instances + + # check that mini-scheduler works + ti.schedule_downstream_tasks()