diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index 57568af199710..5a5cf2d73f15d 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -26,6 +26,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple import pendulum +from deprecated import deprecated from airflow.cli.cli_config import DefaultHelpParser from airflow.configuration import conf @@ -545,7 +546,12 @@ def terminate(self): """Get called when the daemon receives a SIGTERM.""" raise NotImplementedError() - def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]: # pragma: no cover + @deprecated( + reason="Replaced by function `revoke_task`.", + category=RemovedInAirflow3Warning, + action="ignore", + ) + def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]: """ Handle remnants of tasks that were failed because they were stuck in queued. @@ -556,7 +562,23 @@ def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]: # p :param tis: List of Task Instances to clean up :return: List of readable task instances for a warning message """ - raise NotImplementedError() + raise NotImplementedError + + def revoke_task(self, *, ti: TaskInstance): + """ + Attempt to remove task from executor. + + It should attempt to ensure that the task is no longer running on the worker, + and ensure that it is cleared out from internal data structures. + + It should *not* change the state of the task in airflow, or add any events + to the event buffer. + + It should not raise any error. + + :param ti: Task instance to remove + """ + raise NotImplementedError def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]: """ diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index aa4e8d4f26aea..c9afd40f719ed 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -25,12 +25,14 @@ import time import warnings from collections import Counter, defaultdict, deque +from contextlib import suppress from dataclasses import dataclass from datetime import timedelta from functools import lru_cache, partial from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Iterator +from deprecated import deprecated from sqlalchemy import and_, delete, func, not_, or_, select, text, update from sqlalchemy.exc import OperationalError from sqlalchemy.orm import lazyload, load_only, make_transient, selectinload @@ -97,6 +99,9 @@ DR = DagRun DM = DagModel +TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT = "stuck in queued reschedule" +""":meta private:""" + @dataclass class ConcurrencyMap: @@ -228,6 +233,13 @@ def __init__( stalled_task_timeout, task_adoption_timeout, worker_pods_pending_timeout, task_queued_timeout ) + # this param is intentionally undocumented + self._num_stuck_queued_retries = conf.getint( + section="scheduler", + key="num_stuck_in_queued_retries", + fallback=2, + ) + self.do_pickle = do_pickle if log: @@ -1093,7 +1105,7 @@ def _run_scheduler_loop(self) -> None: timers.call_regular_interval( conf.getfloat("scheduler", "task_queued_timeout_check_interval"), - self._fail_tasks_stuck_in_queued, + self._handle_tasks_stuck_in_queued, ) timers.call_regular_interval( @@ -1141,6 +1153,7 @@ def _run_scheduler_loop(self) -> None: for executor in self.job.executors: try: # this is backcompat check if executor does not inherit from BaseExecutor + # todo: remove in airflow 3.0 if not hasattr(executor, "_task_event_logs"): continue with create_session() as session: @@ -1772,48 +1785,132 @@ def _send_sla_callbacks_to_processor(self, dag: DAG) -> None: self.job.executor.send_callback(request) @provide_session - def _fail_tasks_stuck_in_queued(self, session: Session = NEW_SESSION) -> None: + def _handle_tasks_stuck_in_queued(self, session: Session = NEW_SESSION) -> None: """ - Mark tasks stuck in queued for longer than `task_queued_timeout` as failed. + Handle the scenario where a task is queued for longer than `task_queued_timeout`. Tasks can get stuck in queued for a wide variety of reasons (e.g. celery loses track of a task, a cluster can't further scale up its workers, etc.), but tasks - should not be stuck in queued for a long time. This will mark tasks stuck in - queued for longer than `self._task_queued_timeout` as failed. If the task has - available retries, it will be retried. + should not be stuck in queued for a long time. + + We will attempt to requeue the task (by revoking it from executor and setting to + scheduled) up to 2 times before failing the task. """ - self.log.debug("Calling SchedulerJob._fail_tasks_stuck_in_queued method") + tasks_stuck_in_queued = self._get_tis_stuck_in_queued(session) + for executor, stuck_tis in self._executor_to_tis(tasks_stuck_in_queued).items(): + try: + for ti in stuck_tis: + executor.revoke_task(ti=ti) + self._maybe_requeue_stuck_ti( + ti=ti, + session=session, + ) + except NotImplementedError: + # this block only gets entered if the executor has not implemented `revoke_task`. + # in which case, we try the fallback logic + # todo: remove the call to _stuck_in_queued_backcompat_logic in airflow 3.0. + # after 3.0, `cleanup_stuck_queued_tasks` will be removed, so we should + # just continue immediately. + self._stuck_in_queued_backcompat_logic(executor, stuck_tis) + continue - tasks_stuck_in_queued = session.scalars( + def _get_tis_stuck_in_queued(self, session) -> Iterable[TaskInstance]: + """Query db for TIs that are stuck in queued.""" + return session.scalars( select(TI).where( TI.state == TaskInstanceState.QUEUED, TI.queued_dttm < (timezone.utcnow() - timedelta(seconds=self._task_queued_timeout)), TI.queued_by_job_id == self.job.id, ) - ).all() + ) - for executor, stuck_tis in self._executor_to_tis(tasks_stuck_in_queued).items(): - try: - cleaned_up_task_instances = set(executor.cleanup_stuck_queued_tasks(tis=stuck_tis)) - for ti in stuck_tis: - if repr(ti) in cleaned_up_task_instances: - self.log.warning( - "Marking task instance %s stuck in queued as failed. " - "If the task instance has available retries, it will be retried.", - ti, - ) - session.add( - Log( - event="stuck in queued", - task_instance=ti.key, - extra=( - "Task will be marked as failed. If the task instance has " - "available retries, it will be retried." - ), - ) - ) - except NotImplementedError: - self.log.debug("Executor doesn't support cleanup of stuck queued tasks. Skipping.") + def _maybe_requeue_stuck_ti(self, *, ti, session): + """ + Requeue task if it has not been attempted too many times. + + Otherwise, fail it. + """ + num_times_stuck = self._get_num_times_stuck_in_queued(ti, session) + if num_times_stuck < self._num_stuck_queued_retries: + self.log.info("Task stuck in queued; will try to requeue. task_id=%s", ti.task_id) + session.add( + Log( + event=TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT, + task_instance=ti.key, + extra=( + f"Task was in queued state for longer than {self._task_queued_timeout} " + "seconds; task state will be set back to scheduled." + ), + ) + ) + self._reschedule_stuck_task(ti) + else: + self.log.info( + "Task requeue attempts exceeded max; marking failed. task_instance=%s", + ti, + ) + session.add( + Log( + event="stuck in queued tries exceeded", + task_instance=ti.key, + extra=f"Task was requeued more than {self._num_stuck_queued_retries} times and will be failed.", + ) + ) + ti.set_state(TaskInstanceState.FAILED, session=session) + + @deprecated( + reason="This is backcompat layer for older executor interface. Should be removed in 3.0", + category=RemovedInAirflow3Warning, + action="ignore", + ) + def _stuck_in_queued_backcompat_logic(self, executor, stuck_tis): + """ + Try to invoke stuck in queued cleanup for older executor interface. + + TODO: remove in airflow 3.0 + + Here we handle case where the executor pre-dates the interface change that + introduced `cleanup_tasks_stuck_in_queued` and deprecated `cleanup_stuck_queued_tasks`. + + """ + with suppress(NotImplementedError): + for ti_repr in executor.cleanup_stuck_queued_tasks(tis=stuck_tis): + self.log.warning( + "Task instance %s stuck in queued. Will be set to failed.", + ti_repr, + ) + + @provide_session + def _reschedule_stuck_task(self, ti, session=NEW_SESSION): + session.execute( + update(TI) + .where(TI.filter_for_tis([ti])) + .values( + state=TaskInstanceState.SCHEDULED, + queued_dttm=None, + ) + .execution_options(synchronize_session=False) + ) + + @provide_session + def _get_num_times_stuck_in_queued(self, ti: TaskInstance, session: Session = NEW_SESSION) -> int: + """ + Check the Log table to see how many times a taskinstance has been stuck in queued. + + We can then use this information to determine whether to reschedule a task or fail it. + """ + return ( + session.query(Log) + .where( + Log.task_id == ti.task_id, + Log.dag_id == ti.dag_id, + Log.run_id == ti.run_id, + Log.map_index == ti.map_index, + Log.try_number == ti.try_number, + Log.event == TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT, + ) + .count() + ) @provide_session def _emit_pool_metrics(self, session: Session = NEW_SESSION) -> None: @@ -2102,7 +2199,7 @@ def _orphan_unreferenced_datasets(self, session: Session = NEW_SESSION) -> None: updated_count = sum(self._set_orphaned(dataset) for dataset in orphaned_dataset_query) Stats.gauge("dataset.orphaned", updated_count) - def _executor_to_tis(self, tis: list[TaskInstance]) -> dict[BaseExecutor, list[TaskInstance]]: + def _executor_to_tis(self, tis: Iterable[TaskInstance]) -> dict[BaseExecutor, list[TaskInstance]]: """Organize TIs into lists per their respective executor.""" _executor_to_tis: defaultdict[BaseExecutor, list[TaskInstance]] = defaultdict(list) for ti in tis: diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 1ca4c8dc4557d..6f8d5015ed4d5 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1359,6 +1359,7 @@ repos repr req reqs +requeued Reserialize reserialize reserialized diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index d65346857932d..97d3e9fe87d0b 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -28,6 +28,7 @@ from typing import Generator from unittest import mock from unittest.mock import MagicMock, PropertyMock, patch +from uuid import uuid4 import psutil import pytest @@ -55,6 +56,7 @@ from airflow.models.dagrun import DagRun from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel from airflow.models.db_callback_request import DbCallbackRequest +from airflow.models.log import Log from airflow.models.pool import Pool from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey @@ -123,6 +125,19 @@ def load_examples(): # Patch the MockExecutor into the dict of known executors in the Loader +@contextlib.contextmanager +def _loader_mock(mock_executors): + with mock.patch("airflow.executors.executor_loader.ExecutorLoader.load_executor") as loader_mock: + # The executors are mocked, so cannot be loaded/imported. Mock load_executor and return the + # correct object for the given input executor name. + loader_mock.side_effect = lambda *x: { + ("default_exec",): mock_executors[0], + (None,): mock_executors[0], + ("secondary_exec",): mock_executors[1], + }[x] + yield + + @patch.dict( ExecutorLoader.executors, {MOCK_EXECUTOR: f"{MockExecutor.__module__}.{MockExecutor.__qualname__}"} ) @@ -2177,7 +2192,18 @@ def test_adopt_or_reset_orphaned_tasks_multiple_executors(self, dag_maker, mock_ # Second executor called for ti3 mock_executors[1].try_adopt_task_instances.assert_called_once_with([ti3]) - def test_fail_stuck_queued_tasks(self, dag_maker, session, mock_executors): + def test_handle_stuck_queued_tasks_backcompat(self, dag_maker, session, mock_executors): + """ + Verify backward compatibility of the executor interface w.r.t. stuck queued. + + Prior to #43520, scheduler called method `cleanup_stuck_queued_tasks`, which failed tis. + + After #43520, scheduler calls `cleanup_tasks_stuck_in_queued`, which requeues tis. + + At Airflow 3.0, we should remove backcompat support for this old function. But for now + we verify that we call it as a fallback. + """ + # todo: remove in airflow 3.0 with dag_maker("test_fail_stuck_queued_tasks_multiple_executors"): op1 = EmptyOperator(task_id="op1") op2 = EmptyOperator(task_id="op2", executor="default_exec") @@ -2194,26 +2220,102 @@ def test_fail_stuck_queued_tasks(self, dag_maker, session, mock_executors): scheduler_job = Job() job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=0) job_runner._task_queued_timeout = 300 + mock_exec_1 = mock_executors[0] + mock_exec_2 = mock_executors[1] + mock_exec_1.revoke_task.side_effect = NotImplementedError + mock_exec_2.revoke_task.side_effect = NotImplementedError with mock.patch("airflow.executors.executor_loader.ExecutorLoader.load_executor") as loader_mock: # The executors are mocked, so cannot be loaded/imported. Mock load_executor and return the # correct object for the given input executor name. loader_mock.side_effect = lambda *x: { - ("default_exec",): mock_executors[0], - (None,): mock_executors[0], - ("secondary_exec",): mock_executors[1], + ("default_exec",): mock_exec_1, + (None,): mock_exec_1, + ("secondary_exec",): mock_exec_2, }[x] - job_runner._fail_tasks_stuck_in_queued() + job_runner._handle_tasks_stuck_in_queued() # Default executor is called for ti1 (no explicit executor override uses default) and ti2 (where we # explicitly marked that for execution by the default executor) try: - mock_executors[0].cleanup_stuck_queued_tasks.assert_called_once_with(tis=[ti1, ti2]) + mock_exec_1.cleanup_stuck_queued_tasks.assert_called_once_with(tis=[ti1, ti2]) except AssertionError: - mock_executors[0].cleanup_stuck_queued_tasks.assert_called_once_with(tis=[ti2, ti1]) - mock_executors[1].cleanup_stuck_queued_tasks.assert_called_once_with(tis=[ti3]) + mock_exec_1.cleanup_stuck_queued_tasks.assert_called_once_with(tis=[ti2, ti1]) + mock_exec_2.cleanup_stuck_queued_tasks.assert_called_once_with(tis=[ti3]) + + @conf_vars({("scheduler", "num_stuck_in_queued_retries"): "2"}) + def test_handle_stuck_queued_tasks_multiple_attempts(self, dag_maker, session, mock_executors): + """Verify that tasks stuck in queued will be rescheduled up to N times.""" + with dag_maker("test_fail_stuck_queued_tasks_multiple_executors"): + EmptyOperator(task_id="op1") + EmptyOperator(task_id="op2", executor="default_exec") + + def _queue_tasks(tis): + for ti in tis: + ti.state = "queued" + ti.queued_dttm = timezone.utcnow() + session.commit() + + run_id = str(uuid4()) + dr = dag_maker.create_dagrun(run_id=run_id) + + tis = dr.get_task_instances(session=session) + _queue_tasks(tis=tis) + scheduler_job = Job() + scheduler = SchedulerJobRunner(job=scheduler_job, num_runs=0) + # job_runner._reschedule_stuck_task = MagicMock() + scheduler._task_queued_timeout = -300 # always in violation of timeout + + with _loader_mock(mock_executors): + scheduler._handle_tasks_stuck_in_queued(session=session) + + # If the task gets stuck in queued once, we reset it to scheduled + tis = dr.get_task_instances(session=session) + assert [x.state for x in tis] == ["scheduled", "scheduled"] + assert [x.queued_dttm for x in tis] == [None, None] + + _queue_tasks(tis=tis) + log_events = [x.event for x in session.scalars(select(Log).where(Log.run_id == run_id)).all()] + assert log_events == [ + "stuck in queued reschedule", + "stuck in queued reschedule", + ] + + with _loader_mock(mock_executors): + scheduler._handle_tasks_stuck_in_queued(session=session) + session.commit() + + log_events = [x.event for x in session.scalars(select(Log).where(Log.run_id == run_id)).all()] + assert log_events == [ + "stuck in queued reschedule", + "stuck in queued reschedule", + "stuck in queued reschedule", + "stuck in queued reschedule", + ] + mock_executors[0].fail.assert_not_called() + tis = dr.get_task_instances(session=session) + assert [x.state for x in tis] == ["scheduled", "scheduled"] + _queue_tasks(tis=tis) + + with _loader_mock(mock_executors): + scheduler._handle_tasks_stuck_in_queued(session=session) + session.commit() + log_events = [x.event for x in session.scalars(select(Log).where(Log.run_id == run_id)).all()] + assert log_events == [ + "stuck in queued reschedule", + "stuck in queued reschedule", + "stuck in queued reschedule", + "stuck in queued reschedule", + "stuck in queued tries exceeded", + "stuck in queued tries exceeded", + ] + + mock_executors[0].fail.assert_not_called() # just demoing that we don't fail with executor method + states = [x.state for x in dr.get_task_instances(session=session)] + assert states == ["failed", "failed"] - def test_fail_stuck_queued_tasks_raises_not_implemented(self, dag_maker, session, caplog): + def test_revoke_task_not_imp_tolerated(self, dag_maker, session, caplog): + """Test that if executor no implement revoke_task then we don't blow up.""" with dag_maker("test_fail_stuck_queued_tasks"): op1 = EmptyOperator(task_id="op1") @@ -2224,12 +2326,14 @@ def test_fail_stuck_queued_tasks_raises_not_implemented(self, dag_maker, session session.commit() from airflow.executors.local_executor import LocalExecutor + assert "revoke_task" in BaseExecutor.__dict__ + # this is just verifying that LocalExecutor is good enough for this test + # in that it does not implement revoke_task + assert "revoke_task" not in LocalExecutor.__dict__ scheduler_job = Job(executor=LocalExecutor()) job_runner = SchedulerJobRunner(job=scheduler_job, num_runs=0) job_runner._task_queued_timeout = 300 - with caplog.at_level(logging.DEBUG): - job_runner._fail_tasks_stuck_in_queued() - assert "Executor doesn't support cleanup of stuck queued tasks. Skipping." in caplog.text + job_runner._handle_tasks_stuck_in_queued() @mock.patch("airflow.dag_processing.manager.DagFileProcessorAgent") def test_executor_end_called(self, mock_processor_agent, mock_executors):