Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed the ability for Operators to specify their own "scheduling deps". #45713

Merged
merged 2 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion airflow/api_fastapi/core_api/datamodels/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ class PluginResponse(BaseModel):
global_operator_extra_links: list[str]
operator_extra_links: list[str]
source: Annotated[str, BeforeValidator(coerce_to_string)]
ti_deps: list[Annotated[str, BeforeValidator(coerce_to_string)]]
listeners: list[str]
timetables: list[str]

Expand Down
6 changes: 0 additions & 6 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8699,11 +8699,6 @@ components:
source:
type: string
title: Source
ti_deps:
items:
type: string
type: array
title: Ti Deps
listeners:
items:
type: string
Expand All @@ -8725,7 +8720,6 @@ components:
- global_operator_extra_links
- operator_extra_links
- source
- ti_deps
- listeners
- timetables
title: PluginResponse
Expand Down
2 changes: 1 addition & 1 deletion airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,14 +505,14 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
partial_kwargs=partial_kwargs,
task_id=task_id,
params=partial_params,
deps=MappedOperator.deps_for(self.operator_class),
operator_extra_links=self.operator_class.operator_extra_links,
template_ext=self.operator_class.template_ext,
template_fields=self.operator_class.template_fields,
template_fields_renderers=self.operator_class.template_fields_renderers,
ui_color=self.operator_class.ui_color,
ui_fgcolor=self.operator_class.ui_fgcolor,
is_empty=False,
is_sensor=self.operator_class._is_sensor,
task_module=self.operator_class.__module__,
task_type=self.operator_class.__name__,
operator_name=operator_name,
Expand Down
4 changes: 0 additions & 4 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,6 @@ def serialize(self):
)


class UnmappableOperator(AirflowException):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, did not even know such existed :-D Learning something new every day... ah, now need to have un-seen...

"""Raise when an operator is not implemented to be mappable."""


class XComForMappingNotPushed(AirflowException):
"""Raise when a mapped downstream's dependency fails to push XCom for task mapping."""

Expand Down
37 changes: 16 additions & 21 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@
from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Union

import attr
import attrs
import methodtools

from airflow.exceptions import UnmappableOperator
from airflow.models.abstractoperator import (
DEFAULT_EXECUTOR,
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
Expand All @@ -51,7 +50,6 @@
from airflow.models.pool import Pool
from airflow.serialization.enums import DagAttributeTypes
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
from airflow.triggers.base import StartTriggerArgs
from airflow.utils.context import context_update_for_unmapped
from airflow.utils.helpers import is_container, prevent_duplicates
Expand Down Expand Up @@ -140,7 +138,7 @@ def ensure_xcomarg_return_value(arg: Any) -> None:
ensure_xcomarg_return_value(v)


@attr.define(kw_only=True, repr=False)
@attrs.define(kw_only=True, repr=False)
class OperatorPartial:
"""
An "intermediate state" returned by ``BaseOperator.partial()``.
Expand Down Expand Up @@ -193,6 +191,7 @@ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool =

def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
from airflow.operators.empty import EmptyOperator
from airflow.sensors.base import BaseSensorOperator

self._expand_called = True
ensure_xcomarg_return_value(expand_input.value)
Expand All @@ -215,14 +214,14 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
partial_kwargs=partial_kwargs,
task_id=task_id,
params=self.params,
deps=MappedOperator.deps_for(self.operator_class),
operator_extra_links=self.operator_class.operator_extra_links,
template_ext=self.operator_class.template_ext,
template_fields=self.operator_class.template_fields,
template_fields_renderers=self.operator_class.template_fields_renderers,
ui_color=self.operator_class.ui_color,
ui_fgcolor=self.operator_class.ui_fgcolor,
is_empty=issubclass(self.operator_class, EmptyOperator),
is_sensor=issubclass(self.operator_class, BaseSensorOperator),
task_module=self.operator_class.__module__,
task_type=self.operator_class.__name__,
operator_name=operator_name,
Expand All @@ -240,7 +239,7 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
return op


@attr.define(
@attrs.define(
kw_only=True,
# Disable custom __getstate__ and __setstate__ generation since it interacts
# badly with Airflow's DAG serialization and pickling. When a mapped task is
Expand All @@ -267,14 +266,15 @@ class MappedOperator(AbstractOperator):
# Needed for serialization.
task_id: str
params: ParamsDict | dict
deps: frozenset[BaseTIDep]
deps: frozenset[BaseTIDep] = attrs.field(init=False)
operator_extra_links: Collection[BaseOperatorLink]
template_ext: Sequence[str]
template_fields: Collection[str]
template_fields_renderers: dict[str, str]
ui_color: str
ui_fgcolor: str
_is_empty: bool
_is_sensor: bool = False
_task_module: str
_task_type: str
_operator_name: str
Expand All @@ -286,8 +286,8 @@ class MappedOperator(AbstractOperator):
task_group: TaskGroup | None
start_date: pendulum.DateTime | None
end_date: pendulum.DateTime | None
upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
downstream_task_ids: set[str] = attr.ib(factory=set, init=False)
upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
downstream_task_ids: set[str] = attrs.field(factory=set, init=False)

_disallow_kwargs_override: bool
"""Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.
Expand All @@ -308,6 +308,12 @@ class MappedOperator(AbstractOperator):
("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", "start_from_trigger")
)

@deps.default
def _deps(self):
from airflow.models.baseoperator import BaseOperator

return BaseOperator.deps

def __hash__(self):
return id(self)

Expand All @@ -333,7 +339,7 @@ def __attrs_post_init__(self):
@classmethod
def get_serialized_fields(cls):
# Not using 'cls' here since we only want to serialize base fields.
return (frozenset(attr.fields_dict(MappedOperator)) | {"task_type"}) - {
return (frozenset(attrs.fields_dict(MappedOperator)) | {"task_type"}) - {
"_task_type",
"dag",
"deps",
Expand All @@ -346,17 +352,6 @@ def get_serialized_fields(cls):
"_on_failure_fail_dagrun",
}

@methodtools.lru_cache(maxsize=None)
@staticmethod
def deps_for(operator_class: type[BaseOperator]) -> frozenset[BaseTIDep]:
operator_deps = operator_class.deps
if not isinstance(operator_deps, collections.abc.Set):
raise UnmappableOperator(
f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, "
f"not a {type(operator_deps).__name__}"
)
return operator_deps | {MappedTaskIsExpanded()}

@property
def task_type(self) -> str:
"""Implementing Operator."""
Expand Down
25 changes: 0 additions & 25 deletions airflow/plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@
global_operator_extra_links: list[Any] | None = None
operator_extra_links: list[Any] | None = None
registered_operator_link_classes: dict[str, type] | None = None
registered_ti_dep_classes: dict[str, type] | None = None
timetable_classes: dict[str, type[Timetable]] | None = None
hook_lineage_reader_classes: list[type[HookLineageReader]] | None = None
priority_weight_strategy_classes: dict[str, type[PriorityWeightStrategy]] | None = None
Expand All @@ -95,7 +94,6 @@
"global_operator_extra_links",
"operator_extra_links",
"source",
"ti_deps",
"timetables",
"listeners",
"priority_weight_strategies",
Expand Down Expand Up @@ -171,8 +169,6 @@ class AirflowPlugin:
# buttons.
operator_extra_links: list[Any] = []

ti_deps: list[Any] = []

# A list of timetable classes that can be used for DAG scheduling.
timetables: list[type[Timetable]] = []

Expand Down Expand Up @@ -427,27 +423,6 @@ def initialize_fastapi_plugins():
fastapi_apps.extend(plugin.fastapi_apps)


def initialize_ti_deps_plugins():
"""Create modules for loaded extension from custom task instance dependency rule plugins."""
global registered_ti_dep_classes
if registered_ti_dep_classes is not None:
return

ensure_plugins_loaded()

if plugins is None:
raise AirflowPluginException("Can't load plugins.")

log.debug("Initialize custom taskinstance deps plugins")

registered_ti_dep_classes = {}

for plugin in plugins:
registered_ti_dep_classes.update(
{qualname(ti_dep.__class__): ti_dep.__class__ for ti_dep in plugin.ti_deps}
)


def initialize_extra_operators_links_plugins():
"""Create modules for loaded extension from extra operators links plugins."""
global global_operator_extra_links
Expand Down
4 changes: 3 additions & 1 deletion airflow/sensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
ui_color: str = "#e6f1f2"
valid_modes: Iterable[str] = ["poke", "reschedule"]

_is_sensor: bool = True

# Adds one additional dependency for all sensor operators that checks if a
# sensor task instance can be rescheduled.
deps = BaseOperator.deps | {ReadyToRescheduleDep()}
Expand Down Expand Up @@ -406,7 +408,7 @@ def reschedule(self):

@classmethod
def get_serialized_fields(cls):
return super().get_serialized_fields() | {"reschedule"}
return super().get_serialized_fields() | {"reschedule", "_is_sensor"}


def poke_mode_only(cls):
Expand Down
7 changes: 1 addition & 6 deletions airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,6 @@
"items": { "type": "string" }
},
"_is_dummy": { "type": "boolean" },
"deps": {
"description": "list of dep classes -- if non-standard",
"type": "array",
"items": { "type": "string" },
"uniqueItems": true
},
"doc": { "type": "string" },
"doc_md": { "type": "string" },
"doc_json": { "type": "string" },
Expand All @@ -293,6 +287,7 @@
"_logger_name": { "type": "string" },
"_log_config_logger_name": { "type": "string" },
"_is_mapped": { "const": true, "$comment": "only present when True" },
"_is_sensor": { "const": true, "$comment": "only present when True" },
"expand_input": { "type": "object" },
"partial_kwargs": { "type": "object" },
"map_index_template": { "type": "string" },
Expand Down
Loading
Loading