Skip to content

Commit

Permalink
Properly check the existence of missing mapped TIs (#25788)
Browse files Browse the repository at this point in the history
The previous implementation of the missing-index check was not correct: missing indexes
were being checked every time `task_instance_scheduling_decision` was called.
The missing tasks should only be revised after the last-resort expansion of mapped tasks has been done. If we find that a task is in a schedulable state and has already been expanded, we revise its indexes and ensure they are complete. Missing indexes get new task instances created for them, and existing indexes that fall outside the resolved mapping length are marked as removed.
This implementation allows the revision to be done in one place.

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
  • Loading branch information
ephraimbuddy and uranusjr authored Aug 25, 2022
1 parent 29c3316 commit db818ae
Show file tree
Hide file tree
Showing 2 changed files with 440 additions and 80 deletions.
131 changes: 54 additions & 77 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,9 +656,6 @@ def _filter_tis_and_exclude_removed(dag: "DAG", tis: List[TI]) -> Iterable[TI]:
yield ti

tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis))
missing_indexes = self._revise_mapped_task_indexes(tis, session=session)
if missing_indexes:
self.verify_integrity(missing_indexes=missing_indexes, session=session)

unfinished_tis = [t for t in tis if t.state in State.unfinished]
finished_tis = [t for t in tis if t.state in State.finished]
Expand Down Expand Up @@ -730,6 +727,11 @@ def _get_ready_tis(
additional_tis.extend(expanded_tis[1:])
expansion_happened = True
if schedulable.state in SCHEDULEABLE_STATES:
task = schedulable.task
if isinstance(schedulable.task, MappedOperator):
# Ensure the task indexes are complete
created = self._revise_mapped_task_indexes(task, session=session)
ready_tis.extend(created)
ready_tis.append(schedulable)

# Check if any ti changed state
Expand Down Expand Up @@ -825,7 +827,6 @@ def _emit_duration_stats_for_finished_state(self):
def verify_integrity(
self,
*,
missing_indexes: Optional[Dict["MappedOperator", Sequence[int]]] = None,
session: Session = NEW_SESSION,
):
"""
Expand All @@ -842,15 +843,10 @@ def verify_integrity(

dag = self.get_dag()
task_ids: Set[str] = set()
if missing_indexes:
tis = self.get_task_instances(session=session)
for ti in tis:
task_instance_mutation_hook(ti)
task_ids.add(ti.task_id)
else:
task_ids, missing_indexes = self._check_for_removed_or_restored_tasks(
dag, task_instance_mutation_hook, session=session
)

task_ids = self._check_for_removed_or_restored_tasks(
dag, task_instance_mutation_hook, session=session
)

def task_filter(task: "Operator") -> bool:
return task.task_id not in task_ids and (
Expand All @@ -865,29 +861,27 @@ def task_filter(task: "Operator") -> bool:
task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)

# Create the missing tasks, including mapped tasks
tasks = self._create_missing_tasks(dag, task_creator, task_filter, missing_indexes, session=session)
tasks = self._create_tasks(dag, task_creator, task_filter, session=session)

self._create_task_instances(dag.dag_id, tasks, created_counts, hook_is_noop, session=session)

def _check_for_removed_or_restored_tasks(
self, dag: "DAG", ti_mutation_hook, *, session: Session
) -> Tuple[Set[str], Dict["MappedOperator", Sequence[int]]]:
) -> Set[str]:
"""
Check for removed tasks/restored/missing tasks.
:param dag: DAG object corresponding to the dagrun
:param ti_mutation_hook: task_instance_mutation_hook function
:param session: Sqlalchemy ORM Session
:return: List of task_ids in the dagrun and missing task indexes
:return: Task IDs in the DAG run
"""
tis = self.get_task_instances(session=session)

# check for removed or restored tasks
task_ids = set()
existing_indexes: Dict["MappedOperator", List[int]] = defaultdict(list)
expected_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
for ti in tis:
ti_mutation_hook(ti)
task_ids.add(ti.task_id)
Expand Down Expand Up @@ -925,13 +919,9 @@ def _check_for_removed_or_restored_tasks(
elif ti.map_index < 0:
self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti)
ti.state = State.REMOVED
else:
self.log.info("Restoring mapped task '%s'", ti)
Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
existing_indexes[task].append(ti.map_index)
expected_indexes[task] = range(num_mapped_tis)
else:
# What if it is _now_ dynamically mapped, but wasn't before?
task.run_time_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
total_length = task.run_time_mapped_ti_count(self.run_id, session=session)

if total_length is None:
Expand All @@ -950,16 +940,8 @@ def _check_for_removed_or_restored_tasks(
total_length,
)
ti.state = State.REMOVED
else:
self.log.info("Restoring mapped task '%s'", ti)
Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
existing_indexes[task].append(ti.map_index)
expected_indexes[task] = range(total_length)
# Check if we have some missing indexes to create ti for
missing_indexes: Dict["MappedOperator", Sequence[int]] = defaultdict(list)
for k, v in existing_indexes.items():
missing_indexes.update({k: list(set(expected_indexes[k]).difference(v))})
return task_ids, missing_indexes

return task_ids

def _get_task_creator(
self, created_counts: Dict[str, int], ti_mutation_hook: Callable, hook_is_noop: bool
Expand Down Expand Up @@ -995,12 +977,11 @@ def create_ti(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
creator = create_ti
return creator

def _create_missing_tasks(
def _create_tasks(
self,
dag: "DAG",
task_creator: Callable,
task_filter: Callable,
missing_indexes: Optional[Dict["MappedOperator", Sequence[int]]],
*,
session: Session,
) -> Iterable["Operator"]:
Expand Down Expand Up @@ -1031,12 +1012,7 @@ def expand_mapped_literals(
tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))

tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, tasks_and_map_idxs))
if missing_indexes:
# If there are missing indexes, override the tasks to create
new_tasks_and_map_idxs = itertools.starmap(
expand_mapped_literals, [(k, v) for k, v in missing_indexes.items() if len(v) > 0]
)
tasks = itertools.chain.from_iterable(itertools.starmap(task_creator, new_tasks_and_map_idxs))

return tasks

def _create_task_instances(
Expand Down Expand Up @@ -1082,44 +1058,45 @@ def _create_task_instances(
# TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
session.rollback()

def _revise_mapped_task_indexes(
self,
tis: Iterable[TI],
*,
session: Session,
) -> Dict["MappedOperator", Sequence[int]]:
"""Check if the length of the mapped task instances changed at runtime and find the missing indexes.
def _revise_mapped_task_indexes(self, task, session: Session):
"""Check if task increased or reduced in length and handle appropriately"""
from airflow.models.taskinstance import TaskInstance
from airflow.settings import task_instance_mutation_hook

:param tis: Task instances to check
:param session: The session to use
"""
from airflow.models.mappedoperator import MappedOperator
task.run_time_mapped_ti_count.cache_clear()
total_length = (
task.parse_time_mapped_ti_count
or task.run_time_mapped_ti_count(self.run_id, session=session)
or 0
)
query = session.query(TaskInstance.map_index).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == task.task_id,
TaskInstance.run_id == self.run_id,
)
existing_indexes = {i for (i,) in query}
missing_indexes = set(range(total_length)).difference(existing_indexes)
removed_indexes = existing_indexes.difference(range(total_length))
created_tis = []

existing_indexes: Dict[MappedOperator, List[int]] = defaultdict(list)
new_indexes: Dict[MappedOperator, Sequence[int]] = defaultdict(list)
for ti in tis:
task = ti.task
if not isinstance(task, MappedOperator):
continue
# skip unexpanded tasks and also tasks that expands with literal arguments
if ti.map_index < 0 or task.parse_time_mapped_ti_count:
continue
existing_indexes[task].append(ti.map_index)
task.run_time_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
new_length = task.run_time_mapped_ti_count(self.run_id, session=session) or 0

if ti.map_index >= new_length:
self.log.debug(
"Removing task '%s' as the map_index is longer than the resolved mapping list (%d)",
ti,
new_length,
)
ti.state = State.REMOVED
new_indexes[task] = range(new_length)
missing_indexes: Dict[MappedOperator, Sequence[int]] = defaultdict(list)
for k, v in existing_indexes.items():
missing_indexes.update({k: list(set(new_indexes[k]).difference(v))})
return missing_indexes
if missing_indexes:
for index in missing_indexes:
ti = TaskInstance(task, run_id=self.run_id, map_index=index, state=None)
self.log.debug("Expanding TIs upserted %s", ti)
task_instance_mutation_hook(ti)
ti = session.merge(ti)
ti.refresh_from_task(task)
session.flush()
created_tis.append(ti)
elif removed_indexes:
session.query(TaskInstance).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == task.task_id,
TaskInstance.run_id == self.run_id,
TaskInstance.map_index.in_(removed_indexes),
).update({TaskInstance.state: TaskInstanceState.REMOVED})
session.flush()
return created_tis

@staticmethod
def get_run(session: Session, dag_id: str, execution_date: datetime) -> Optional['DagRun']:
Expand Down
Loading

0 comments on commit db818ae

Please sign in to comment.