AIP-51 - Executor Coupling in Logging (#28161)
Executors may now implement a method to vend task logs
snjypl authored Jan 24, 2023
1 parent 1fbfd31 commit 3b25168
Showing 11 changed files with 272 additions and 139 deletions.
9 changes: 9 additions & 0 deletions airflow/executors/base_executor.py
@@ -355,6 +355,15 @@ def execute_async(
"""
raise NotImplementedError()

def get_task_log(self, ti: TaskInstance, log: str = "") -> None | str | tuple[str, dict[str, bool]]:
"""
This method can be implemented by any child class to return the task logs.
:param ti: A TaskInstance object
:param log: log str
:return: logs or tuple of logs and meta dict
"""

def end(self) -> None: # pragma: no cover
"""Wait synchronously for the previously submitted job to complete."""
raise NotImplementedError()
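The new hook is deliberately permissive about its return value: `None` means "this executor has no task logs", a plain string is treated as the log text, and a `(log, metadata)` tuple is handed back to the caller as-is. A minimal sketch of a custom executor implementing it — `MyExecutor` and its in-memory `_log_store` are hypothetical names used only for illustration, not part of this commit:

from __future__ import annotations

from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstance import TaskInstance


class MyExecutor(BaseExecutor):
    """Hypothetical executor that keeps task logs in an in-memory store."""

    def __init__(self):
        super().__init__()
        self._log_store: dict[str, str] = {}  # illustrative: keyed by task_id

    def get_task_log(self, ti: TaskInstance, log: str = "") -> None | str | tuple[str, dict[str, bool]]:
        stored = self._log_store.get(ti.task_id)
        if stored is None:
            # Returning None lets the file task handler fall back to fetching the log from the worker.
            return None
        log += f"*** Reading logs for {ti.task_id} from MyExecutor's store ***\n"
        log += stored
        # A (log, metadata) tuple is passed back to the caller unchanged.
        return log, {"end_of_log": True}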
6 changes: 6 additions & 0 deletions airflow/executors/celery_kubernetes_executor.py
@@ -141,6 +141,12 @@ def queue_task_instance(
cfg_path=cfg_path,
)

def get_task_log(self, ti: TaskInstance, log: str = "") -> None | str | tuple[str, dict[str, bool]]:
"""Fetch task log from Kubernetes executor"""
if ti.queue == self.kubernetes_executor.kubernetes_queue:
return self.kubernetes_executor.get_task_log(ti=ti, log=log)
return None

def has_task(self, task_instance: TaskInstance) -> bool:
"""
Checks if a task is either queued or running in either celery or kubernetes executor.
53 changes: 53 additions & 0 deletions airflow/executors/kubernetes_executor.py
@@ -28,6 +28,7 @@
import multiprocessing
import time
from collections import defaultdict
from contextlib import suppress
from datetime import timedelta
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple
@@ -37,6 +38,7 @@
from kubernetes.client.rest import ApiException
from urllib3.exceptions import ReadTimeoutError

from airflow.configuration import conf
from airflow.exceptions import AirflowException, PodMutationHookException, PodReconciliationError
from airflow.executors.base_executor import BaseExecutor, CommandType
from airflow.kubernetes import pod_generator
@@ -771,6 +773,57 @@ def _change_state(self, key: TaskInstanceKey, state: str | None, pod_id: str, na
# do this once, so only do it when we remove the task from running
self.event_buffer[key] = state, None

@staticmethod
def _get_pod_namespace(ti: TaskInstance):
pod_override = ti.executor_config.get("pod_override")
namespace = None
with suppress(Exception):
namespace = pod_override.metadata.namespace
return namespace or conf.get("kubernetes_executor", "namespace", fallback="default")

def get_task_log(self, ti: TaskInstance, log: str = "") -> str | tuple[str, dict[str, bool]]:

try:
from airflow.kubernetes.pod_generator import PodGenerator

client = get_kube_client()

log += f"*** Trying to get logs (last 100 lines) from worker pod {ti.hostname} ***\n\n"
selector = PodGenerator.build_selector_for_k8s_executor_pod(
dag_id=ti.dag_id,
task_id=ti.task_id,
try_number=ti.try_number,
map_index=ti.map_index,
run_id=ti.run_id,
airflow_worker=ti.queued_by_job_id,
)
namespace = self._get_pod_namespace(ti)
pod_list = client.list_namespaced_pod(
namespace=namespace,
label_selector=selector,
).items
if not pod_list:
raise RuntimeError("Cannot find pod for ti %s", ti)
elif len(pod_list) > 1:
raise RuntimeError("Found multiple pods for ti %s: %s", ti, pod_list)
res = client.read_namespaced_pod_log(
name=pod_list[0].metadata.name,
namespace=namespace,
container="base",
follow=False,
tail_lines=100,
_preload_content=False,
)

for line in res:
log += line.decode()

return log

except Exception as f:
log += f"*** Unable to fetch logs from worker pod {ti.hostname} ***\n{str(f)}\n\n"
return log, {"end_of_log": True}

def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
tis_to_flush = [ti for ti in tis if not ti.queued_by_job_id]
scheduler_job_ids = {ti.queued_by_job_id for ti in tis}
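A note on `_get_pod_namespace` above: the namespace from a task's `pod_override` wins, otherwise the `[kubernetes_executor] namespace` option (falling back to `default`) is used. A rough sketch of that precedence — the `SimpleNamespace` stand-in for a TaskInstance is an assumption made for the example, and running it needs an Airflow environment with the Kubernetes client installed:

from types import SimpleNamespace

from kubernetes.client import models as k8s

from airflow.executors.kubernetes_executor import KubernetesExecutor

# A pod_override with an explicit namespace takes precedence.
ti_with_override = SimpleNamespace(
    executor_config={"pod_override": k8s.V1Pod(metadata=k8s.V1ObjectMeta(namespace="team-a"))}
)
assert KubernetesExecutor._get_pod_namespace(ti_with_override) == "team-a"

# Without an override, the configured [kubernetes_executor] namespace (or "default") is returned.
ti_without_override = SimpleNamespace(executor_config={})
print(KubernetesExecutor._get_pod_namespace(ti_without_override))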
7 changes: 7 additions & 0 deletions airflow/executors/local_kubernetes_executor.py
@@ -142,6 +142,13 @@ def queue_task_instance(
cfg_path=cfg_path,
)

def get_task_log(self, ti: TaskInstance, log: str = "") -> None | str | tuple[str, dict[str, bool]]:
"""Fetch task log from kubernetes executor"""
if ti.queue == self.kubernetes_executor.kubernetes_queue:
return self.kubernetes_executor.get_task_log(ti=ti, log=log)

return None

def has_task(self, task_instance: TaskInstance) -> bool:
"""
Checks if a task is either queued or running in either local or kubernetes executor.
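Both hybrid executors route purely on the task's queue: only tasks queued on the configured `kubernetes_queue` are delegated to the embedded Kubernetes executor; everything else yields `None`. A small, illustrative usage sketch with mocked component executors (it mirrors the CeleryKubernetesExecutor test added further down and assumes the default queue name `kubernetes`):

from unittest import mock

from airflow.executors.local_kubernetes_executor import LocalKubernetesExecutor

local_mock, k8s_mock = mock.MagicMock(), mock.MagicMock()
k8s_mock.kubernetes_queue = "kubernetes"  # normally read from [local_kubernetes_executor] kubernetes_queue
executor = LocalKubernetesExecutor(local_mock, k8s_mock)

ti = mock.MagicMock(queue="kubernetes")
executor.get_task_log(ti=ti, log="")  # delegated to the embedded Kubernetes executor
k8s_mock.get_task_log.assert_called_once_with(ti=ti, log="")

ti.queue = "default"
assert executor.get_task_log(ti=ti, log="") is None  # local tasks expose no executor-side logs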
166 changes: 64 additions & 102 deletions airflow/utils/log/file_task_handler.py
@@ -26,8 +26,9 @@
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin

from airflow.configuration import AirflowConfigException, conf
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException, RemovedInAirflow3Warning
from airflow.executors.executor_loader import ExecutorLoader
from airflow.utils.context import Context
from airflow.utils.helpers import parse_template_string, render_template_to_string
from airflow.utils.log.logging_mixin import SetContextPropagate
@@ -146,23 +147,54 @@ def _render_filename(self, ti: TaskInstance, try_number: int) -> str:
def _read_grouped_logs(self):
return False

    @staticmethod
    def _should_check_k8s(queue):
        """
        If the task is running through kubernetes executor, return True.

        When logs aren't available locally, in this case we read from k8s pod logs.
        """
        executor = conf.get("core", "executor")
        if executor == "KubernetesExecutor":
            return True
        elif executor == "LocalKubernetesExecutor":
            if queue == conf.get("local_kubernetes_executor", "kubernetes_queue"):
                return True
        elif executor == "CeleryKubernetesExecutor":
            if queue == conf.get("celery_kubernetes_executor", "kubernetes_queue"):
                return True
        return False

    def _get_task_log_from_worker(
        self, ti: TaskInstance, log: str, log_relative_path: str
    ) -> str | tuple[str, dict[str, bool]]:
        import httpx

        from airflow.utils.jwt_signer import JWTSigner

        url = self._get_log_retrieval_url(ti, log_relative_path)
        log += f"*** Fetching from: {url}\n"

        try:
            timeout = None  # No timeout
            try:
                timeout = conf.getint("webserver", "log_fetch_timeout_sec")
            except (AirflowConfigException, ValueError):
                pass

            signer = JWTSigner(
                secret_key=conf.get("webserver", "secret_key"),
                expiration_time_in_seconds=conf.getint("webserver", "log_request_clock_grace", fallback=30),
                audience="task-instance-logs",
            )
            response = httpx.get(
                url,
                timeout=timeout,
                headers={"Authorization": signer.generate_signed_token({"filename": log_relative_path})},
            )
            response.encoding = "utf-8"

            if response.status_code == 403:
                log += (
                    "*** !!!! Please make sure that all your Airflow components (e.g. "
                    "schedulers, webservers and workers) have "
                    "the same 'secret_key' configured in 'webserver' section and "
                    "time is synchronized on all your machines (for example with ntpd) !!!!!\n***"
                )
                log += (
                    "*** See more at https://airflow.apache.org/docs/apache-airflow/"
                    "stable/configurations-ref.html#secret-key\n***"
                )
            # Check if the resource was properly fetched
            response.raise_for_status()

            log += "\n" + response.text
            return log
        except Exception as e:
            log += f"*** Failed to fetch log file from worker. {str(e)}\n"
            return log, {"end_of_log": True}

def _read(self, ti: TaskInstance, try_number: int, metadata: dict[str, Any] | None = None):
"""
@@ -186,8 +218,6 @@ def _read(self, ti: TaskInstance, try_number: int, metadata: dict[str, Any] | No
This is determined by the status of the TaskInstance
log_pos: (absolute) Char position to which the log is retrieved
"""
from airflow.utils.jwt_signer import JWTSigner

# Task instance here might be different from task instance when
# initializing the handler. Thus explicitly getting log location
# is needed to get correct log path.
@@ -204,91 +234,23 @@ def _read(self, ti: TaskInstance, try_number: int, metadata: dict[str, Any] | No
log = f"*** Failed to load local log file: {location}\n"
log += f"*** {str(e)}\n"
return log, {"end_of_log": True}
        elif self._should_check_k8s(ti.queue):
            try:
                from airflow.kubernetes.kube_client import get_kube_client
                from airflow.kubernetes.pod_generator import PodGenerator

                client = get_kube_client()

                log += f"*** Trying to get logs (last 100 lines) from worker pod {ti.hostname} ***\n\n"
                selector = PodGenerator.build_selector_for_k8s_executor_pod(
                    dag_id=ti.dag_id,
                    task_id=ti.task_id,
                    try_number=ti.try_number,
                    map_index=ti.map_index,
                    run_id=ti.run_id,
                    airflow_worker=ti.queued_by_job_id,
                )
                namespace = self._get_pod_namespace(ti)
                pod_list = client.list_namespaced_pod(
                    namespace=namespace,
                    label_selector=selector,
                ).items
                if not pod_list:
                    raise RuntimeError("Cannot find pod for ti %s", ti)
                elif len(pod_list) > 1:
                    raise RuntimeError("Found multiple pods for ti %s: %s", ti, pod_list)
                res = client.read_namespaced_pod_log(
                    name=pod_list[0].metadata.name,
                    namespace=namespace,
                    container="base",
                    follow=False,
                    tail_lines=100,
                    _preload_content=False,
                )

                for line in res:
                    log += line.decode()

            except Exception as f:
                log += f"*** Unable to fetch logs from worker pod {ti.hostname} ***\n{str(f)}\n\n"
                return log, {"end_of_log": True}
        else:
            import httpx

            url = self._get_log_retrieval_url(ti, log_relative_path)
            log += f"*** Log file does not exist: {location}\n"
            log += f"*** Fetching from: {url}\n"
            try:
                timeout = None  # No timeout
                try:
                    timeout = conf.getint("webserver", "log_fetch_timeout_sec")
                except (AirflowConfigException, ValueError):
                    pass

                signer = JWTSigner(
                    secret_key=conf.get("webserver", "secret_key"),
                    expiration_time_in_seconds=conf.getint(
                        "webserver", "log_request_clock_grace", fallback=30
                    ),
                    audience="task-instance-logs",
                )
                response = httpx.get(
                    url,
                    timeout=timeout,
                    headers={"Authorization": signer.generate_signed_token({"filename": log_relative_path})},
                )
                response.encoding = "utf-8"

                if response.status_code == 403:
                    log += (
                        "*** !!!! Please make sure that all your Airflow components (e.g. "
                        "schedulers, webservers and workers) have "
                        "the same 'secret_key' configured in 'webserver' section and "
                        "time is synchronized on all your machines (for example with ntpd) !!!!!\n***"
                    )
                    log += (
                        "*** See more at https://airflow.apache.org/docs/apache-airflow/"
                        "stable/configurations-ref.html#secret-key\n***"
                    )
                # Check if the resource was properly fetched
                response.raise_for_status()

                log += "\n" + response.text
            except Exception as e:
                log += f"*** Failed to fetch log file from worker. {str(e)}\n"
                return log, {"end_of_log": True}
        else:
            log += f"*** Local log file does not exist: {location}\n"
            executor = ExecutorLoader.get_default_executor()
            task_log = None

            task_log = executor.get_task_log(ti=ti, log=log)
            if isinstance(task_log, tuple):
                return task_log

            if task_log is None:
                log += "*** Failed to fetch log from executor. Falling back to fetching log from worker.\n"
                task_log = self._get_task_log_from_worker(ti, log, log_relative_path=log_relative_path)

            if isinstance(task_log, tuple):
                return task_log

            log = str(task_log)

        # Process tailing if log is not at its end
        end_of_log = ti.try_number != try_number or ti.state not in [State.RUNNING, State.DEFERRED]
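Taken together, `_read` now follows a small dispatch protocol when the log file is not available locally: ask the configured executor first, treat a `(log, metadata)` tuple as a final answer, fall back to the worker's HTTP log endpoint when the executor returns `None`, and only then continue with plain string handling. A condensed, illustrative restatement of that control flow (the helper name and signature below are made up for clarity; the real logic lives inside `FileTaskHandler._read`):

from typing import Any


def resolve_remote_task_log(handler: Any, executor: Any, ti: Any, log: str, log_relative_path: str):
    """Illustrative condensation of the new fallback chain in FileTaskHandler._read."""
    task_log = executor.get_task_log(ti=ti, log=log)  # None, str, or (str, metadata)

    if isinstance(task_log, tuple):
        # The executor produced a final answer (e.g. pod logs, or an error plus
        # {"end_of_log": True}); it is returned unchanged.
        return task_log

    if task_log is None:
        # The executor does not serve logs for this task: fetch the file from the
        # worker's log-serving endpoint with a JWT-signed request instead.
        log += "*** Failed to fetch log from executor. Falling back to fetching log from worker.\n"
        task_log = handler._get_task_log_from_worker(ti, log, log_relative_path=log_relative_path)

    if isinstance(task_log, tuple):
        return task_log

    return str(task_log)  # end_of_log is decided later from the TaskInstance state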
8 changes: 7 additions & 1 deletion tests/executors/test_base_executor.py
@@ -27,7 +27,7 @@

from airflow.executors.base_executor import BaseExecutor, RunningRetryAttemptType
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.utils import timezone
from airflow.utils.state import State

@@ -44,6 +44,12 @@ def test_is_local_default_value():
assert not BaseExecutor.is_local


def test_get_task_log():
executor = BaseExecutor()
ti = TaskInstance(task=BaseOperator(task_id="dummy"))
assert executor.get_task_log(ti=ti) is None


def test_serve_logs_default_value():
assert not BaseExecutor.serve_logs

16 changes: 16 additions & 0 deletions tests/executors/test_celery_kubernetes_executor.py
@@ -173,6 +173,22 @@ def mock_ti(queue):
celery_executor_mock.try_adopt_task_instances.assert_called_once_with(celery_tis)
k8s_executor_mock.try_adopt_task_instances.assert_called_once_with(k8s_tis)

def test_log_is_fetched_from_k8s_executor_only_for_k8s_queue(self):
celery_executor_mock = mock.MagicMock()
k8s_executor_mock = mock.MagicMock()
cke = CeleryKubernetesExecutor(celery_executor_mock, k8s_executor_mock)
simple_task_instance = mock.MagicMock()
simple_task_instance.queue = KUBERNETES_QUEUE
cke.get_task_log(ti=simple_task_instance, log="")
k8s_executor_mock.get_task_log.assert_called_once_with(ti=simple_task_instance, log=mock.ANY)

k8s_executor_mock.reset_mock()

simple_task_instance.queue = "test-queue"
log = cke.get_task_log(ti=simple_task_instance, log="")
k8s_executor_mock.get_task_log.assert_not_called()
assert log is None

def test_get_event_buffer(self):
celery_executor_mock = mock.MagicMock()
k8s_executor_mock = mock.MagicMock()