Skip to content

Commit

Permalink
KubernetesPodOperator new callbacks and allow multiple callbacks (apa…
Browse files Browse the repository at this point in the history
  • Loading branch information
johnhoran authored and niklasr22 committed Feb 8, 2025
1 parent d5b840e commit 9351bd0
Show file tree
Hide file tree
Showing 6 changed files with 314 additions and 58 deletions.
95 changes: 88 additions & 7 deletions providers/src/airflow/providers/cncf/kubernetes/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@
from __future__ import annotations

from enum import Enum
from typing import Union
from typing import TYPE_CHECKING, Union

import kubernetes.client as k8s
import kubernetes_asyncio.client as async_k8s

if TYPE_CHECKING:
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.utils.context import Context

client_type = Union[k8s.CoreV1Api, async_k8s.CoreV1Api]


Expand All @@ -41,7 +45,7 @@ class KubernetesPodOperatorCallback:
"""

@staticmethod
def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None:
def on_sync_client_creation(*, client: k8s.CoreV1Api, operator: KubernetesPodOperator, **kwargs) -> None:
"""
Invoke this callback after creating the sync client.
Expand All @@ -50,7 +54,34 @@ def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None:
pass

@staticmethod
def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
def on_pod_manifest_created(
*,
pod_request: k8s.V1Pod,
client: client_type,
mode: str,
operator: KubernetesPodOperator,
context: Context,
**kwargs,
) -> None:
"""
Invoke this callback after KPO creates the V1Pod manifest but before the pod is created.
:param pod_request: the kubernetes pod manifest
:param client: the Kubernetes client that can be used in the callback.
:param mode: the current execution mode, it's one of (`sync`, `async`).
"""
pass

@staticmethod
def on_pod_creation(
*,
pod: k8s.V1Pod,
client: client_type,
mode: str,
operator: KubernetesPodOperator,
context: Context,
**kwargs,
) -> None:
"""
Invoke this callback after creating the pod.
Expand All @@ -61,7 +92,15 @@ def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs)
pass

@staticmethod
def on_pod_starting(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
def on_pod_starting(
*,
pod: k8s.V1Pod,
client: client_type,
mode: str,
operator: KubernetesPodOperator,
context: Context,
**kwargs,
) -> None:
"""
Invoke this callback when the pod starts.
Expand All @@ -72,7 +111,15 @@ def on_pod_starting(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs)
pass

@staticmethod
def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None:
def on_pod_completion(
*,
pod: k8s.V1Pod,
client: client_type,
mode: str,
operator: KubernetesPodOperator,
context: Context,
**kwargs,
) -> None:
"""
Invoke this callback when the pod completes.
Expand All @@ -83,7 +130,34 @@ def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwarg
pass

@staticmethod
def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs):
def on_pod_teardown(
*,
pod: k8s.V1Pod,
client: client_type,
mode: str,
operator: KubernetesPodOperator,
context: Context,
**kwargs,
) -> None:
"""
Invoke this callback after all pod completion callbacks but before the pod is deleted.
:param pod: the completed pod.
:param client: the Kubernetes client that can be used in the callback.
:param mode: the current execution mode, it's one of (`sync`, `async`).
"""
pass

@staticmethod
def on_pod_cleanup(
*,
pod: k8s.V1Pod,
client: client_type,
mode: str,
operator: KubernetesPodOperator,
context: Context,
**kwargs,
):
"""
Invoke this callback after cleaning/deleting the pod.
Expand All @@ -95,7 +169,14 @@ def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs):

@staticmethod
def on_operator_resuming(
*, pod: k8s.V1Pod, event: dict, client: client_type, mode: str, **kwargs
*,
pod: k8s.V1Pod,
event: dict,
client: client_type,
mode: str,
operator: KubernetesPodOperator,
context: Context,
**kwargs,
) -> None:
"""
Invoke this callback when resuming the `KubernetesPodOperator` from deferred state.
Expand Down
116 changes: 83 additions & 33 deletions providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,9 @@ def __init__(
is_delete_operator_pod: None | bool = None,
termination_message_policy: str = "File",
active_deadline_seconds: int | None = None,
callbacks: type[KubernetesPodOperatorCallback] | None = None,
callbacks: (
list[type[KubernetesPodOperatorCallback]] | type[KubernetesPodOperatorCallback] | None
) = None,
progress_callback: Callable[[str], None] | None = None,
logging_interval: int | None = None,
**kwargs,
Expand Down Expand Up @@ -415,7 +417,7 @@ def __init__(

self._config_dict: dict | None = None # TODO: remove it when removing convert_config_file_to_dict
self._progress_callback = progress_callback
self.callbacks = callbacks
self.callbacks = [] if not callbacks else callbacks if isinstance(callbacks, list) else [callbacks]
self._killed: bool = False

@cached_property
Expand Down Expand Up @@ -519,8 +521,9 @@ def hook(self) -> PodOperatorHookProtocol:
@cached_property
def client(self) -> CoreV1Api:
client = self.hook.core_v1_client
if self.callbacks:
self.callbacks.on_sync_client_creation(client=client)

for callback in self.callbacks:
callback.on_sync_client_creation(client=client, operator=self)
return client

def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool = True) -> k8s.V1Pod | None:
Expand Down Expand Up @@ -594,6 +597,14 @@ def execute_sync(self, context: Context):
try:
if self.pod_request_obj is None:
self.pod_request_obj = self.build_pod_request_obj(context)
for callback in self.callbacks:
callback.on_pod_manifest_created(
pod_request=self.pod_request_obj,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)
if self.pod is None:
self.pod = self.get_or_create_pod( # must set `self.pod` for `on_kill`
pod_request_obj=self.pod_request_obj,
Expand All @@ -606,28 +617,48 @@ def execute_sync(self, context: Context):

# get remote pod for use in cleanup methods
self.remote_pod = self.find_pod(self.pod.metadata.namespace, context=context)
if self.callbacks:
self.callbacks.on_pod_creation(
pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC
for callback in self.callbacks:
callback.on_pod_creation(
pod=self.remote_pod,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)

self.await_init_containers_completion(pod=self.pod)

self.await_pod_start(pod=self.pod)
if self.callbacks:
self.callbacks.on_pod_starting(
pod=self.find_pod(self.pod.metadata.namespace, context=context),
client=self.client,
mode=ExecutionMode.SYNC,
)
pod = self.find_pod(self.pod.metadata.namespace, context=context)
for callback in self.callbacks:
callback.on_pod_starting(
pod=pod,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)

self.await_pod_completion(pod=self.pod)
if self.callbacks:
self.callbacks.on_pod_completion(
pod=self.find_pod(self.pod.metadata.namespace, context=context),
client=self.client,
mode=ExecutionMode.SYNC,
)
pod = self.find_pod(self.pod.metadata.namespace, context=context)
for callback in self.callbacks:
callback.on_pod_completion(
pod=pod,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)
for callback in self.callbacks:
callback.on_pod_teardown(
pod=pod,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)

if self.do_xcom_push:
self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod)
Expand All @@ -642,8 +673,14 @@ def execute_sync(self, context: Context):
pod=pod_to_clean,
remote_pod=self.remote_pod,
)
if self.callbacks:
self.callbacks.on_pod_cleanup(pod=pod_to_clean, client=self.client, mode=ExecutionMode.SYNC)
for callback in self.callbacks:
callback.on_pod_cleanup(
pod=pod_to_clean,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)

if self.do_xcom_push:
return result
Expand Down Expand Up @@ -710,11 +747,15 @@ def execute_async(self, context: Context) -> None:
context=context,
)
if self.callbacks:
self.callbacks.on_pod_creation(
pod=self.find_pod(self.pod.metadata.namespace, context=context),
client=self.client,
mode=ExecutionMode.SYNC,
)
pod = self.find_pod(self.pod.metadata.namespace, context=context)
for callback in self.callbacks:
callback.on_pod_creation(
pod=pod,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)
ti = context["ti"]
ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)
Expand Down Expand Up @@ -775,10 +816,16 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
if not self.pod:
raise PodNotFoundException("Could not find pod after resuming from deferral")

if self.callbacks and event["status"] != "running":
self.callbacks.on_operator_resuming(
pod=self.pod, event=event, client=self.client, mode=ExecutionMode.SYNC
)
if event["status"] != "running":
for callback in self.callbacks:
callback.on_operator_resuming(
pod=self.pod,
event=event,
client=self.client,
mode=ExecutionMode.SYNC,
context=context,
operator=self,
)

follow = self.logging_interval is None
last_log_time = event.get("last_log_time")
Expand Down Expand Up @@ -821,9 +868,9 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
except TaskDeferred:
raise
finally:
self._clean(event)
self._clean(event, context)

def _clean(self, event: dict[str, Any]) -> None:
def _clean(self, event: dict[str, Any], context: Context) -> None:
if event["status"] == "running":
return
istio_enabled = self.is_istio_enabled(self.pod)
Expand All @@ -846,6 +893,7 @@ def _clean(self, event: dict[str, Any]) -> None:
self.post_complete_action(
pod=self.pod,
remote_pod=self.pod,
context=context,
)

def _write_logs(self, pod: k8s.V1Pod, follow: bool = False, since_time: DateTime | None = None) -> None:
Expand Down Expand Up @@ -875,14 +923,16 @@ def _write_logs(self, pod: k8s.V1Pod, follow: bool = False, since_time: DateTime
e if not isinstance(e, ApiException) else e.reason,
)

def post_complete_action(self, *, pod, remote_pod, **kwargs) -> None:
def post_complete_action(self, *, pod, remote_pod, context: Context, **kwargs) -> None:
"""Actions that must be done after operator finishes logic of the deferrable_execution."""
self.cleanup(
pod=pod,
remote_pod=remote_pod,
)
if self.callbacks:
self.callbacks.on_pod_cleanup(pod=pod, client=self.client, mode=ExecutionMode.SYNC)
for callback in self.callbacks:
callback.on_pod_cleanup(
pod=pod, client=self.client, mode=ExecutionMode.SYNC, operator=self, context=context
)

def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod):
# Skip cleaning the pod in the following scenarios.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ class PodManager(LoggingMixin):
def __init__(
self,
kube_client: client.CoreV1Api,
callbacks: type[KubernetesPodOperatorCallback] | None = None,
callbacks: list[type[KubernetesPodOperatorCallback]] | None = None,
):
"""
Create the launcher.
Expand All @@ -331,7 +331,7 @@ def __init__(
super().__init__()
self._client = kube_client
self._watch = watch.Watch()
self._callbacks = callbacks
self._callbacks = callbacks or []

def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod:
"""Run POD asynchronously."""
Expand Down Expand Up @@ -466,8 +466,8 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None
progress_callback_lines.append(line)
else: # previous log line is complete
for line in progress_callback_lines:
if self._callbacks:
self._callbacks.progress_callback(
for callback in self._callbacks:
callback.progress_callback(
line=line, client=self._client, mode=ExecutionMode.SYNC
)
if message_to_log is not None:
Expand All @@ -485,8 +485,8 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None
finally:
# log the last line and update the last_captured_timestamp
for line in progress_callback_lines:
if self._callbacks:
self._callbacks.progress_callback(
for callback in self._callbacks:
callback.progress_callback(
line=line, client=self._client, mode=ExecutionMode.SYNC
)
if message_to_log is not None:
Expand Down
Loading

0 comments on commit 9351bd0

Please sign in to comment.