diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 9aa909f06cecf4..b2795b7bf9d786 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -227,7 +227,7 @@ def _get_ti(
         )

     ti_or_none = dag_run.get_task_instance(task.task_id, map_index=map_index, session=session)
-    ti: TaskInstance | TaskInstancePydantic
+    ti: TaskInstance
     if ti_or_none is None:
         if not create_if_necessary:
             raise TaskInstanceNotFound(
@@ -249,9 +249,7 @@ def _get_ti(
     return ti, dr_created


-def _run_task_by_selected_method(
-    args, dag: DAG, ti: TaskInstance | TaskInstancePydantic
-) -> None | TaskReturnCode:
+def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None | TaskReturnCode:
     """
     Run the task based on a mode.

@@ -308,7 +306,7 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) -> None:
         executor.end()


-def _run_task_by_local_task_job(args, ti: TaskInstance | TaskInstancePydantic) -> TaskReturnCode | None:
+def _run_task_by_local_task_job(args, ti: TaskInstance) -> TaskReturnCode | None:
     """Run LocalTaskJob, which monitors the raw task execution process."""
     job_runner = LocalTaskJobRunner(
         job=Job(dag_id=ti.dag_id),
@@ -354,7 +352,7 @@ def _extract_external_executor_id(args) -> str | None:


 @contextmanager
-def _move_task_handlers_to_root(ti: TaskInstance | TaskInstancePydantic) -> Generator[None, None, None]:
+def _move_task_handlers_to_root(ti: TaskInstance) -> Generator[None, None, None]:
     """
     Move handlers for task logging to root logger.

@@ -381,7 +379,7 @@ def _move_task_handlers_to_root(ti: TaskInstance | TaskInstancePydantic) -> Gene


 @contextmanager
-def _redirect_stdout_to_ti_log(ti: TaskInstance | TaskInstancePydantic) -> Generator[None, None, None]:
+def _redirect_stdout_to_ti_log(ti: TaskInstance) -> Generator[None, None, None]:
     """
     Redirect stdout to ti logger.
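Editorial note: with the Pydantic union gone from the CLI helpers, a ti returned by _get_ti is always an ORM TaskInstance, so call sites need no type narrowing. A minimal sketch under that assumption (describe_ti is a hypothetical name, not part of the patch):

    # Hedged sketch: ti is always an ORM TaskInstance now, so plain
    # attribute access is safe without isinstance()/cast() guards.
    def describe_ti(ti) -> str:
        return f"{ti.dag_id}.{ti.task_id} run={ti.run_id} try={ti.try_number}"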
diff --git a/airflow/jobs/local_task_job_runner.py b/airflow/jobs/local_task_job_runner.py
index 5d1f15fe8f8a51..e28d69bd1dcefc 100644
--- a/airflow/jobs/local_task_job_runner.py
+++ b/airflow/jobs/local_task_job_runner.py
@@ -42,7 +42,6 @@
     from airflow.jobs.job import Job
     from airflow.models.taskinstance import TaskInstance
-    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

 SIGSEGV_MESSAGE = """
 ******************************************* Received SIGSEGV *******************************************
@@ -83,7 +82,7 @@ class LocalTaskJobRunner(BaseJobRunner, LoggingMixin):
     def __init__(
         self,
         job: Job,
-        task_instance: TaskInstance | TaskInstancePydantic,
+        task_instance: TaskInstance,
         ignore_all_deps: bool = False,
         ignore_depends_on_past: bool = False,
         wait_for_past_depends_before_skipping: bool = False,
diff --git a/airflow/models/renderedtifields.py b/airflow/models/renderedtifields.py
index b8ab93cf41a776..94b12a5c44f0df 100644
--- a/airflow/models/renderedtifields.py
+++ b/airflow/models/renderedtifields.py
@@ -49,7 +49,7 @@
     from sqlalchemy.sql import FromClause

     from airflow.models import Operator
-    from airflow.models.taskinstance import TaskInstance, TaskInstancePydantic
+    from airflow.models.taskinstance import TaskInstance


 def get_serialized_template_fields(task: Operator):
@@ -173,9 +173,7 @@ def _update_runtime_evaluated_template_fields(

     @classmethod
     @provide_session
-    def get_templated_fields(
-        cls, ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION
-    ) -> dict | None:
+    def get_templated_fields(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> dict | None:
         """
         Get templated field for a TaskInstance from the RenderedTaskInstanceFields table.
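Editorial note: a minimal usage sketch for the narrowed classmethod, assuming an ORM TaskInstance in scope; the session is injected by @provide_session when omitted. Illustrative, not part of the patch:

    from airflow.models.renderedtifields import RenderedTaskInstanceFields

    # Returns the stored rendered template fields for this TI, or None.
    def rendered_bash_command(ti):
        fields = RenderedTaskInstanceFields.get_templated_fields(ti)
        return (fields or {}).get("bash_command")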
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 87cd4b2d931e7b..c6cfb28959d5be 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -39,7 +39,6 @@
     from airflow.models.operator import Operator
     from airflow.sdk.definitions.node import DAGNode
     from airflow.serialization.pydantic.dag_run import DagRunPydantic
-    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

 # The key used by SkipMixin to store XCom data.
 XCOM_SKIPMIXIN_KEY = "skipmixin_key"
@@ -153,7 +152,7 @@ def _skip(

     def skip_all_except(
         self,
-        ti: TaskInstance | TaskInstancePydantic,
+        ti: TaskInstance,
         branch_task_ids: None | str | Iterable[str],
     ):
         """Facade for compatibility for call to internal API."""
@@ -167,7 +166,7 @@ def skip_all_except(
     @provide_session
     def _skip_all_except(
         cls,
-        ti: TaskInstance | TaskInstancePydantic,
+        ti: TaskInstance,
         branch_task_ids: None | str | Iterable[str],
         session: Session = NEW_SESSION,
     ):
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 591f3549bab39f..d6b24f34000f4f 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -72,7 +72,6 @@
 from sqlalchemy_utils import UUIDType

 from airflow import settings
-from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.assets.manager import asset_manager
 from airflow.configuration import conf
 from airflow.exceptions import (
@@ -165,7 +164,6 @@
     from airflow.sdk.definitions.dag import DAG
     from airflow.serialization.pydantic.asset import AssetEventPydantic
     from airflow.serialization.pydantic.dag import DagModelPydantic
-    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
     from airflow.timetables.base import DataInterval
     from airflow.typing_compat import Literal, TypeGuard
     from airflow.utils.task_group import TaskGroup
@@ -185,14 +183,6 @@ class TaskReturnCode(Enum):
     """When task exits with deferral to trigger."""


-@internal_api_call
-@provide_session
-def _merge_ti(ti, session: Session = NEW_SESSION):
-    session.merge(ti)
-    session.commit()
-
-
-@internal_api_call
 @provide_session
 def _add_log(
     event,
@@ -215,14 +205,13 @@
     )


-@internal_api_call
 @provide_session
 def _update_ti_heartbeat(id: str, when: datetime, session: Session = NEW_SESSION):
     session.execute(update(TaskInstance).where(TaskInstance.id == id).values(last_heartbeat_at=when))


 def _run_raw_task(
-    ti: TaskInstance | TaskInstancePydantic,
+    ti: TaskInstance,
     mark_success: bool = False,
     test_mode: bool = False,
     pool: str | None = None,
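Editorial note: with @internal_api_call removed, helpers such as _add_log and _update_ti_heartbeat always execute against the local metadata database rather than being routed over the internal API. A hedged sketch of a direct call, assuming a ti object in scope (@provide_session supplies the session):

    from airflow.models.taskinstance import _update_ti_heartbeat
    from airflow.utils import timezone

    # Writes last_heartbeat_at for the TI row identified by its id column.
    _update_ti_heartbeat(id=ti.id, when=timezone.utcnow())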
@@ -397,7 +386,7 @@ def set_current_context(context: Context) -> Generator[Context, None, None]:
         )


-def _stop_remaining_tasks(*, task_instance: TaskInstance | TaskInstancePydantic, session: Session):
+def _stop_remaining_tasks(*, task_instance: TaskInstance, session: Session):
     """
     Stop non-teardown tasks in dag.

@@ -546,7 +535,6 @@ def clear_task_instances(
     session.flush()


-@internal_api_call
 @provide_session
 def _xcom_pull(
     *,
@@ -678,7 +666,7 @@ def _creator_note(val):
         return TaskInstanceNote(*val)


-def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: Context, task_orig: Operator):
+def _execute_task(task_instance: TaskInstance, context: Context, task_orig: Operator):
     """
     Execute Task (optionally with a Timeout) and push Xcom results.

@@ -843,7 +831,7 @@ def _set_ti_attrs(target, source, include_dag_run=False):

 def _refresh_from_db(
     *,
-    task_instance: TaskInstance | TaskInstancePydantic,
+    task_instance: TaskInstance,
     session: Session | None = None,
     lock_for_update: bool = False,
 ) -> None:
@@ -868,19 +856,12 @@ def _refresh_from_db(
     )

     if ti:
-        from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
-
-        include_dag_run = isinstance(ti, TaskInstancePydantic)
-        # in case of internal API, what we get is TaskInstancePydantic value, and we are supposed
-        # to also update dag_run information as it might not be available. We cannot always do it in
-        # case ti is TaskInstance, because it might be detached/ not loaded yet and dag_run might
-        # not be available.
-        _set_ti_attrs(task_instance, ti, include_dag_run=include_dag_run)
+        _set_ti_attrs(task_instance, ti, include_dag_run=False)
     else:
         task_instance.state = None


-def _set_duration(*, task_instance: TaskInstance | TaskInstancePydantic) -> None:
+def _set_duration(*, task_instance: TaskInstance) -> None:
     """
     Set task instance duration.

@@ -895,7 +876,7 @@ def _set_duration(*, task_instance: TaskInstance | TaskInstancePydantic) -> None
     log.debug("Task Duration set to %s", task_instance.duration)


-def _stats_tags(*, task_instance: TaskInstance | TaskInstancePydantic) -> dict[str, str]:
+def _stats_tags(*, task_instance: TaskInstance) -> dict[str, str]:
     """
     Return task instance tags.

@@ -906,7 +887,7 @@ def _stats_tags(*, task_instance: TaskInstance | TaskInstancePydantic) -> dict[s
     return prune_dict({"dag_id": task_instance.dag_id, "task_id": task_instance.task_id})


-def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydantic) -> None:
+def _clear_next_method_args(*, task_instance: TaskInstance) -> None:
     """
     Ensure we unset next_method and next_kwargs to ensure that any retries don't reuse them.

@@ -920,10 +901,9 @@ def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydanti
     task_instance.next_kwargs = None


-@internal_api_call
 def _get_template_context(
     *,
-    task_instance: TaskInstance | TaskInstancePydantic,
+    task_instance: TaskInstance,
     dag: SchedulerDAG,
     session: Session | None = None,
     ignore_param_exceptions: bool = True,
@@ -1072,7 +1052,7 @@ def get_triggering_events() -> dict[str, list[AssetEvent | AssetEventPydantic]]:
     return Context(context)  # type: ignore


-def _is_eligible_to_retry(*, task_instance: TaskInstance | TaskInstancePydantic):
+def _is_eligible_to_retry(*, task_instance: TaskInstance):
     """
     Is task instance is eligible for retry.

@@ -1095,10 +1075,9 @@ def _is_eligible_to_retry(*, task_instance: TaskInstance | TaskInstancePydantic)


 @provide_session
-@internal_api_call
 def _handle_failure(
     *,
-    task_instance: TaskInstance | TaskInstancePydantic,
+    task_instance: TaskInstance,
     error: None | str | BaseException,
     session: Session,
     test_mode: bool | None = None,
@@ -1120,7 +1099,6 @@ def _handle_failure(
     """
     if test_mode is None:
         test_mode = task_instance.test_mode
-    task_instance = _coalesce_to_orm_ti(ti=task_instance, session=session)
     failure_context = TaskInstance.fetch_handle_failure_context(
         ti=task_instance,  # type: ignore[arg-type]
         error=error,
@@ -1172,12 +1150,11 @@
                 "operator": str(task_instance.operator),
             }
         )
-        if isinstance(task_instance, TaskInstance):
-            span.set_attribute("log_url", task_instance.log_url)
+        span.set_attribute("log_url", task_instance.log_url)


 def _refresh_from_task(
-    *, task_instance: TaskInstance | TaskInstancePydantic, task: Operator, pool_override: str | None = None
+    *, task_instance: TaskInstance, task: Operator, pool_override: str | None = None
 ) -> None:
     """
     Copy common attributes from the given task.
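Editorial note: _stats_tags above builds metric tags via prune_dict, which drops None-valued entries. A small sketch of the resulting shape, assuming a ti object in scope (example values are illustrative):

    from airflow.utils.helpers import prune_dict

    tags = prune_dict({"dag_id": ti.dag_id, "task_id": ti.task_id})
    # e.g. {"dag_id": "example_dag", "task_id": "extract"}, suitable for
    # passing as tags= to Stats counters.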
@@ -1208,11 +1185,10 @@ def _refresh_from_task(
     task_instance_mutation_hook(task_instance)


-@internal_api_call
 @provide_session
 def _record_task_map_for_downstreams(
     *,
-    task_instance: TaskInstance | TaskInstancePydantic,
+    task_instance: TaskInstance,
     task: Operator,
     dag: DAG,
     value: Any,
@@ -1257,7 +1233,7 @@ def _record_task_map_for_downstreams(

 def _get_previous_dagrun(
     *,
-    task_instance: TaskInstance | TaskInstancePydantic,
+    task_instance: TaskInstance,
     state: DagRunState | None = None,
     session: Session | None = None,
 ) -> DagRun | None:
@@ -1300,7 +1276,7 @@ def _get_previous_dagrun(

 def _get_previous_logical_date(
     *,
-    task_instance: TaskInstance | TaskInstancePydantic,
+    task_instance: TaskInstance,
     state: DagRunState | None,
     session: Session,
 ) -> pendulum.DateTime | None:
@@ -1320,7 +1296,7 @@ def _get_previous_logical_date(

 def _get_previous_start_date(
     *,
-    task_instance: TaskInstance | TaskInstancePydantic,
+    task_instance: TaskInstance,
     state: DagRunState | None,
     session: Session,
 ) -> pendulum.DateTime | None:
@@ -1337,9 +1313,7 @@ def _get_previous_start_date(
     return pendulum.instance(prev_ti.start_date) if prev_ti and prev_ti.start_date else None


-def _email_alert(
-    *, task_instance: TaskInstance | TaskInstancePydantic, exception, task: BaseOperator
-) -> None:
+def _email_alert(*, task_instance: TaskInstance, exception, task: BaseOperator) -> None:
     """
     Send alert email with exception information.

@@ -1360,7 +1334,7 @@ def _email_alert(

 def _get_email_subject_content(
     *,
-    task_instance: TaskInstance | TaskInstancePydantic,
+    task_instance: TaskInstance,
     exception: BaseException,
     task: BaseOperator | None = None,
 ) -> tuple[str, str, str]:
@@ -1481,7 +1455,7 @@ def get_callback_representation(callback: TaskStateChangeCallback) -> Any:
             log.exception("Error in callback at index %d: %s", idx, callback_repr)


-def _log_state(*, task_instance: TaskInstance | TaskInstancePydantic, lead_msg: str = "") -> None:
+def _log_state(*, task_instance: TaskInstance, lead_msg: str = "") -> None:
     """
     Log task state.

@@ -1512,7 +1486,7 @@ def _log_state(*, task_instance: TaskInstance | TaskInstancePydantic, lead_msg:
     )


-def _date_or_empty(*, task_instance: TaskInstance | TaskInstancePydantic, attr: str) -> str:
+def _date_or_empty(*, task_instance: TaskInstance, attr: str) -> str:
     """
     Fetch a date attribute or None of it does not exist.

@@ -1527,10 +1501,10 @@ def _date_or_empty(*, task_instance: TaskInstance | TaskInstancePydantic, attr:

 def _get_previous_ti(
     *,
-    task_instance: TaskInstance | TaskInstancePydantic,
+    task_instance: TaskInstance,
     session: Session,
     state: DagRunState | None = None,
-) -> TaskInstance | TaskInstancePydantic | None:
+) -> TaskInstance | None:
     """
     Get task instance for the task that ran before this task instance.
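Editorial note: the previous-run helpers above keep their keyword-only signatures; only the union type is narrowed. A hedged call-site sketch through the public wrapper, assuming a ti object in scope:

    from airflow.utils.state import DagRunState

    # Delegates to _get_previous_start_date; returns a pendulum.DateTime or None.
    prev_start = ti.get_previous_start_date(state=DagRunState.SUCCESS)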
@@ -1546,7 +1520,6 @@ def _get_previous_ti(
     return dagrun.get_task_instance(task_instance.task_id, session=session)


-@internal_api_call
 @provide_session
 def _update_rtif(ti, rendered_fields, session: Session = NEW_SESSION):
     from airflow.models.renderedtifields import RenderedTaskInstanceFields
@@ -1557,33 +1530,12 @@ def _update_rtif(ti, rendered_fields, session: Session = NEW_SESSION):
     RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id, session=session)


-def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Session):
-    from airflow.models.dagrun import DagRun
-    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
-
-    if isinstance(ti, TaskInstancePydantic):
-        orm_ti = DagRun.fetch_task_instance(
-            dag_id=ti.dag_id,
-            dag_run_id=ti.run_id,
-            task_id=ti.task_id,
-            map_index=ti.map_index,
-            session=session,
-        )
-        if TYPE_CHECKING:
-            assert orm_ti
-        ti, pydantic_ti = orm_ti, ti
-        _set_ti_attrs(ti, pydantic_ti)
-        ti.task = pydantic_ti.task
-    return ti
-
-
-@internal_api_call
 @provide_session
 def _defer_task(
-    ti: TaskInstance | TaskInstancePydantic,
+    ti: TaskInstance,
     exception: TaskDeferred | None = None,
     session: Session = NEW_SESSION,
-) -> TaskInstancePydantic | TaskInstance:
+) -> TaskInstance:
     from airflow.models.trigger import Trigger

     if exception is not None:
@@ -1614,8 +1566,6 @@ def _defer_task(
     session.add(trigger_row)
     session.flush()

-    ti = _coalesce_to_orm_ti(ti=ti, session=session)  # ensure orm obj in case it's pydantic
-
     if TYPE_CHECKING:
         assert ti.task

@@ -1652,7 +1602,6 @@ def _defer_task(
     return ti


-@internal_api_call
 @provide_session
 def _handle_reschedule(
     ti,
@@ -1665,8 +1614,6 @@ def _handle_reschedule(
     if test_mode:
         return

-    ti = _coalesce_to_orm_ti(ti=ti, session=session)
-
     ti.refresh_from_db(session)

     if TYPE_CHECKING:
@@ -1932,7 +1879,7 @@ def task_display_name(self) -> str:

     @staticmethod
     def _command_as_list(
-        ti: TaskInstance | TaskInstancePydantic,
+        ti: TaskInstance,
         mark_success: bool = False,
         ignore_all_deps: bool = False,
         ignore_task_deps: bool = False,
@@ -2142,7 +2089,6 @@ def error(self, session: Session = NEW_SESSION) -> None:
         session.commit()

     @classmethod
-    @internal_api_call
     @provide_session
     def get_task_instance(
         cls,
@@ -2152,7 +2098,7 @@ def get_task_instance(
         map_index: int,
         lock_for_update: bool = False,
         session: Session = NEW_SESSION,
-    ) -> TaskInstance | TaskInstancePydantic | None:
+    ) -> TaskInstance | None:
         query = (
             session.query(TaskInstance)
             .options(lazyload(TaskInstance.dag_run))  # lazy load dag run to avoid locking it
@@ -2195,9 +2141,8 @@ def refresh_from_task(self, task: Operator, pool_override: str | None = None) ->
         _refresh_from_task(task_instance=self, task=task, pool_override=pool_override)

     @staticmethod
-    @internal_api_call
     @provide_session
-    def _clear_xcom_data(ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION) -> None:
+    def _clear_xcom_data(ti: TaskInstance, session: Session = NEW_SESSION) -> None:
         """
         Clear all XCom data from the database for the task instance.
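Editorial note: _defer_task now both accepts and returns an ORM TaskInstance, which is why the _coalesce_to_orm_ti round trip above could be deleted. A hedged sketch of the control flow at a caller, simplified from the real raw-task loop (task, ti, and context are assumed in scope):

    from airflow.exceptions import TaskDeferred
    from airflow.models.taskinstance import _defer_task

    try:
        result = task.execute(context)
    except TaskDeferred as deferral:
        # Persists the trigger row and moves the same ORM object to DEFERRED.
        ti = _defer_task(ti, exception=deferral)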
@@ -2231,8 +2176,7 @@ def key(self) -> TaskInstanceKey:
         return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)

     @staticmethod
-    @internal_api_call
-    def _set_state(ti: TaskInstance | TaskInstancePydantic, state, session: Session) -> bool:
+    def _set_state(ti: TaskInstance, state, session: Session) -> bool:
         if not isinstance(ti, TaskInstance):
             ti = session.scalars(
                 select(TaskInstance).where(
@@ -2322,7 +2266,7 @@ def get_previous_ti(
         self,
         state: DagRunState | None = None,
         session: Session = NEW_SESSION,
-    ) -> TaskInstance | TaskInstancePydantic | None:
+    ) -> TaskInstance | None:
         """
         Return the task instance for the task that ran before this task instance.

@@ -2465,7 +2409,6 @@ def ready_for_retry(self) -> bool:
         return self.state == TaskInstanceState.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow()

     @staticmethod
-    @internal_api_call
     def _get_dagrun(dag_id, run_id, session) -> DagRun:
         from airflow.models.dagrun import DagRun  # Avoid circular import

@@ -2500,9 +2443,7 @@ def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:

     @classmethod
     @provide_session
-    def ensure_dag(
-        cls, task_instance: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION
-    ) -> DAG:
+    def ensure_dag(cls, task_instance: TaskInstance, session: Session = NEW_SESSION) -> DAG:
         """Ensure that task has a dag object associated, might have been removed by serialization."""
         if TYPE_CHECKING:
             assert task_instance.task
@@ -2515,11 +2456,10 @@ def ensure_dag(
         return task_instance.task.dag

     @classmethod
-    @internal_api_call
     @provide_session
     def _check_and_change_state_before_execution(
         cls,
-        task_instance: TaskInstance | TaskInstancePydantic,
+        task_instance: TaskInstance,
         verbose: bool = True,
         ignore_all_deps: bool = False,
         ignore_depends_on_past: bool = False,
@@ -2556,14 +2496,7 @@ def _check_and_change_state_before_execution(
         if TYPE_CHECKING:
             assert task_instance.task

-        if isinstance(task_instance, TaskInstance):
-            ti: TaskInstance = task_instance
-        else:  # isinstance(task_instance, TaskInstancePydantic)
-            filters = (col == getattr(task_instance, col.name) for col in inspect(TaskInstance).primary_key)
-            ti = session.query(TaskInstance).filter(*filters).scalar()
-            dag = DagBag(read_dags_from_db=True).get_dag(task_instance.dag_id, session=session)
-            task_instance.task = dag.task_dict[ti.task_id]
-            ti.task = task_instance.task
+        ti: TaskInstance = task_instance
         task = task_instance.task
         if TYPE_CHECKING:
             assert task
@@ -2791,7 +2724,6 @@ def _register_asset_changes(
         TaskInstance._register_asset_changes_int(ti=self, events=events)

     @staticmethod
-    @internal_api_call
     @provide_session
     def _register_asset_changes_int(
         ti: TaskInstance, *, events: OutletEventAccessors, session: Session = NEW_SESSION
@@ -3174,10 +3106,8 @@ def fetch_handle_failure_context(
         }

     @staticmethod
-    @internal_api_call
     @provide_session
-    def save_to_db(ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION):
-        ti = _coalesce_to_orm_ti(ti=ti, session=session)
+    def save_to_db(ti: TaskInstance, session: Session = NEW_SESSION):
         ti.updated_at = timezone.utcnow()
         session.merge(ti)
         session.flush()
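Editorial note: save_to_db likewise drops its coalescing step and merges the ORM row directly. A hedged usage sketch, assuming a ti object in scope (session injected by @provide_session when omitted):

    from airflow.models.taskinstance import TaskInstance

    # Stamps updated_at, merges the row into the session, and flushes.
    TaskInstance.save_to_db(ti)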
diff --git a/airflow/models/taskinstancehistory.py b/airflow/models/taskinstancehistory.py
index e587cf083e3b57..9ac11cad7dba5d 100644
--- a/airflow/models/taskinstancehistory.py
+++ b/airflow/models/taskinstancehistory.py
@@ -47,7 +47,6 @@

 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance
-    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic


 class TaskInstanceHistory(Base):
@@ -96,7 +95,7 @@ class TaskInstanceHistory(Base):

     def __init__(
         self,
-        ti: TaskInstance | TaskInstancePydantic,
+        ti: TaskInstance,
         state: str | None = None,
     ):
         super().__init__()
diff --git a/airflow/models/taskmap.py b/airflow/models/taskmap.py
index 478f09e0f1148c..2702b906df034d 100644
--- a/airflow/models/taskmap.py
+++ b/airflow/models/taskmap.py
@@ -31,7 +31,6 @@

 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance
-    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic


 class TaskMapVariant(enum.Enum):
@@ -98,7 +97,7 @@ def __init__(
         self.keys = keys

     @classmethod
-    def from_task_instance_xcom(cls, ti: TaskInstance | TaskInstancePydantic, value: Collection) -> TaskMap:
+    def from_task_instance_xcom(cls, ti: TaskInstance, value: Collection) -> TaskMap:
         if ti.run_id is None:
             raise ValueError("cannot record task map for unrun task instance")
         return cls(
diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py
index aa987294e8ee90..a60386d1ecb7cf 100644
--- a/airflow/models/taskreschedule.py
+++ b/airflow/models/taskreschedule.py
@@ -34,7 +34,6 @@
     from sqlalchemy.sql import Select

     from airflow.models.taskinstance import TaskInstance
-    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic


 class TaskReschedule(TaskInstanceDependencies):
@@ -101,7 +100,7 @@ def __init__(
     @classmethod
     def stmt_for_task_instance(
         cls,
-        ti: TaskInstance | TaskInstancePydantic,
+        ti: TaskInstance,
         *,
         try_number: int | None = None,
         descending: bool = False,
diff --git a/airflow/operators/branch.py b/airflow/operators/branch.py
index 088aea23fd338f..81a82e9d12a082 100644
--- a/airflow/operators/branch.py
+++ b/airflow/operators/branch.py
@@ -27,7 +27,6 @@

 if TYPE_CHECKING:
     from airflow.models import TaskInstance
-    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
     from airflow.utils.context import Context


@@ -42,7 +41,7 @@ def do_branch(self, context: Context, branches_to_execute: str | Iterable[str])
         return branches_to_execute

     def _expand_task_group_roots(
-        self, ti: TaskInstance | TaskInstancePydantic, branches_to_execute: str | Iterable[str]
+        self, ti: TaskInstance, branches_to_execute: str | Iterable[str]
     ) -> Iterable[str]:
         """Expand any task group into its root task ids."""
         if TYPE_CHECKING:
diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py
index 50cc9e98edc234..a593a4519f11f8 100644
--- a/airflow/sensors/base.py
+++ b/airflow/sensors/base.py
@@ -45,9 +45,11 @@
 from airflow.models.taskreschedule import TaskReschedule
 from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
 from airflow.utils import timezone
-from airflow.utils.session import create_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session

 if TYPE_CHECKING:
+    from sqlalchemy.orm.session import Session
+
     from airflow.utils.context import Context

 # As documented in https://dev.mysql.com/doc/refman/5.7/en/datetime.html.
@@ -80,6 +82,30 @@ def __bool__(self) -> bool:
         return self.is_done


+@provide_session
+def _orig_start_date(
+    dag_id: str, task_id: str, run_id: str, map_index: int, try_number: int, session: Session = NEW_SESSION
+):
+    """
+    Get the original start_date for a rescheduled task.
+
+    :meta private:
+    """
+    return session.scalar(
+        select(TaskReschedule)
+        .where(
+            TaskReschedule.dag_id == dag_id,
+            TaskReschedule.task_id == task_id,
+            TaskReschedule.run_id == run_id,
+            TaskReschedule.map_index == map_index,
+            TaskReschedule.try_number == try_number,
+        )
+        .order_by(TaskReschedule.id.asc())
+        .with_only_columns(TaskReschedule.start_date)
+        .limit(1)
+    )
+
+
 class BaseSensorOperator(BaseOperator, SkipMixin):
     """
     Sensor operators are derived from this class and inherit these attributes.
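Editorial note: a hedged usage sketch for the new module-level _orig_start_date helper (private, per its :meta private: docstring). A reschedule-mode sensor could recover the first start date of the current try like this, assuming a ti object in scope:

    from airflow.sensors.base import _orig_start_date

    first_start = _orig_start_date(
        dag_id=ti.dag_id,
        task_id=ti.task_id,
        run_id=ti.run_id,
        map_index=ti.map_index,
        try_number=ti.try_number,
    )
    # Earliest TaskReschedule.start_date for this try, or None if the
    # task has never been rescheduled.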
diff --git a/airflow/serialization/pydantic/dag_run.py b/airflow/serialization/pydantic/dag_run.py
index e9409e3a8ac19a..a31f2c35927ae5 100644
--- a/airflow/serialization/pydantic/dag_run.py
+++ b/airflow/serialization/pydantic/dag_run.py
@@ -32,7 +32,6 @@
     from sqlalchemy.orm import Session

     from airflow.jobs.scheduler_job_runner import TI
-    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
     from airflow.utils.state import TaskInstanceState


@@ -89,7 +88,7 @@ def get_task_instance(
         session: Session,
         *,
         map_index: int = -1,
-    ) -> TI | TaskInstancePydantic | None:
+    ) -> TI | None:
         """
         Return the task instance specified by task_id for this dag run.
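Editorial note: the narrowed return annotation means callers of DagRunPydantic.get_task_instance no longer need to handle a Pydantic variant. A hedged sketch (task id and variables are illustrative):

    ti = dag_run.get_task_instance("extract", session=session, map_index=-1)
    if ti is not None:  # annotation is now TI | None
        print(ti.state)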
diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py
index 431903a8b9fce8..43bfd527a74278 100644
--- a/airflow/serialization/pydantic/taskinstance.py
+++ b/airflow/serialization/pydantic/taskinstance.py
@@ -28,32 +28,26 @@
     PlainValidator,
 )

-from airflow.exceptions import AirflowRescheduleException, TaskDeferred
+from airflow.exceptions import AirflowRescheduleException
 from airflow.models import Operator
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.taskinstance import (
     TaskInstance,
-    TaskReturnCode,
-    _defer_task,
     _handle_reschedule,
-    _run_raw_task,
     _set_ti_attrs,
 )
 from airflow.serialization.pydantic.dag import DagModelPydantic
 from airflow.serialization.pydantic.dag_run import DagRunPydantic
 from airflow.utils import timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.net import get_hostname
 from airflow.utils.xcom import XCOM_RETURN_KEY

 if TYPE_CHECKING:
-    import pendulum
     from pydantic import ValidationInfo
     from sqlalchemy.orm import Session

     from airflow.models.dagrun import DagRun
     from airflow.utils.context import Context
-    from airflow.utils.state import DagRunState


 def serialize_operator(x: Operator | None) -> dict | None:
@@ -65,7 +59,6 @@ def serialize_operator(x: Operator | None) -> dict | None:


 def validated_operator(x: dict[str, Any] | Operator, _info: ValidationInfo) -> Any:
-    from airflow.models.baseoperator import BaseOperator
     from airflow.models.mappedoperator import MappedOperator

     if isinstance(x, BaseOperator) or isinstance(x, MappedOperator) or x is None:
@@ -132,29 +125,6 @@ class TaskInstancePydantic(BaseModelPydantic, LoggingMixin):
     def _logger_name(self):
         return "airflow.task"

-    def clear_xcom_data(self, session: Session | None = None):
-        TaskInstance._clear_xcom_data(ti=self, session=session)
-
-    def set_state(self, state, session: Session | None = None) -> bool:
-        return TaskInstance._set_state(ti=self, state=state, session=session)
-
-    def _run_raw_task(
-        self,
-        mark_success: bool = False,
-        test_mode: bool = False,
-        pool: str | None = None,
-        raise_on_defer: bool = False,
-        session: Session | None = None,
-    ) -> TaskReturnCode | None:
-        return _run_raw_task(
-            ti=self,
-            mark_success=mark_success,
-            test_mode=test_mode,
-            pool=pool,
-            raise_on_defer=raise_on_defer,
-            session=session,
-        )
-
     def _run_execute_callback(self, context, task):
         TaskInstance._run_execute_callback(self=self, context=context, task=task)  # type: ignore[arg-type]

@@ -241,274 +211,21 @@ def _execute_task(self, context, task_orig):
         return _execute_task(task_instance=self, context=context, task_orig=task_orig)

-    def refresh_from_db(self, session: Session | None = None, lock_for_update: bool = False) -> None:
-        """
-        Refresh the task instance from the database based on the primary key.
-
-        :param session: SQLAlchemy ORM Session
-        :param lock_for_update: if True, indicates that the database should
-            lock the TaskInstance (issuing a FOR UPDATE clause) until the
-            session is committed.
-        """
-        from airflow.models.taskinstance import _refresh_from_db
-
-        _refresh_from_db(task_instance=self, session=session, lock_for_update=lock_for_update)
-
     def update_heartbeat(self):
         """Update the recorded heartbeat for this task to "now"."""
         from airflow.models.taskinstance import _update_ti_heartbeat

         return _update_ti_heartbeat(self.id, timezone.utcnow())

-    def set_duration(self) -> None:
-        """Set task instance duration."""
-        from airflow.models.taskinstance import _set_duration
-
-        _set_duration(task_instance=self)
-
-    @property
-    def stats_tags(self) -> dict[str, str]:
-        """Return task instance tags."""
-        from airflow.models.taskinstance import _stats_tags
-
-        return _stats_tags(task_instance=self)
-
-    def clear_next_method_args(self) -> None:
-        """Ensure we unset next_method and next_kwargs to ensure that any retries don't reuse them."""
-        from airflow.models.taskinstance import _clear_next_method_args
-
-        _clear_next_method_args(task_instance=self)
-
-    def get_template_context(
-        self,
-        session: Session | None = None,
-        ignore_param_exceptions: bool = True,
-    ) -> Context:
-        """
-        Return TI Context.
-
-        :param session: SQLAlchemy ORM Session
-        :param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
-        """
-        from airflow.models.taskinstance import _get_template_context
-
-        if TYPE_CHECKING:
-            assert self.task
-            assert self.task.dag
-        return _get_template_context(
-            task_instance=self,
-            dag=self.task.dag,
-            session=session,
-            ignore_param_exceptions=ignore_param_exceptions,
-        )
-
     def is_eligible_to_retry(self):
         """Is task instance is eligible for retry."""
         from airflow.models.taskinstance import _is_eligible_to_retry

         return _is_eligible_to_retry(task_instance=self)

-    def handle_failure(
-        self,
-        error: None | str | BaseException,
-        test_mode: bool | None = None,
-        context: Context | None = None,
-        force_fail: bool = False,
-        session: Session | None = None,
-    ) -> None:
-        """
-        Handle Failure for a task instance.
-
-        :param error: if specified, log the specific exception if thrown
-        :param session: SQLAlchemy ORM Session
-        :param test_mode: doesn't record success or failure in the DB if True
-        :param context: Jinja2 context
-        :param force_fail: if True, task does not retry
-        """
-        from airflow.models.taskinstance import _handle_failure
-
-        if TYPE_CHECKING:
-            assert self.task
-            assert self.task.dag
-        try:
-            fail_stop = self.task.dag.fail_stop
-        except Exception:
-            fail_stop = False
-        _handle_failure(
-            task_instance=self,
-            error=error,
-            session=session,
-            test_mode=test_mode,
-            context=context,
-            force_fail=force_fail,
-            fail_stop=fail_stop,
-        )
-
-    def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None:
-        """
-        Copy common attributes from the given task.
-
-        :param task: The task object to copy from
-        :param pool_override: Use the pool_override instead of task's pool
-        """
-        from airflow.models.taskinstance import _refresh_from_task
-
-        _refresh_from_task(task_instance=self, task=task, pool_override=pool_override)
-
-    def get_previous_dagrun(
-        self,
-        state: DagRunState | None = None,
-        session: Session | None = None,
-    ) -> DagRun | None:
-        """
-        Return the DagRun that ran before this task instance's DagRun.
-
-        :param state: If passed, it only take into account instances of a specific state.
-        :param session: SQLAlchemy ORM Session.
-        """
-        from airflow.models.taskinstance import _get_previous_dagrun
-
-        return _get_previous_dagrun(task_instance=self, state=state, session=session)
-
-    def get_previous_logical_date(
-        self,
-        state: DagRunState | None = None,
-        session: Session | None = None,
-    ) -> pendulum.DateTime | None:
-        """
-        Return the logical date from property previous_ti_success.
-
-        :param state: If passed, it only take into account instances of a specific state.
-        :param session: SQLAlchemy ORM Session
-        """
-        from airflow.models.taskinstance import _get_previous_logical_date
-
-        return _get_previous_logical_date(task_instance=self, state=state, session=session)
-
-    def get_previous_start_date(
-        self,
-        state: DagRunState | None = None,
-        session: Session | None = None,
-    ) -> pendulum.DateTime | None:
-        """
-        Return the logical date from property previous_ti_success.
-
-        :param state: If passed, it only take into account instances of a specific state.
-        :param session: SQLAlchemy ORM Session
-        """
-        from airflow.models.taskinstance import _get_previous_start_date
-
-        return _get_previous_start_date(task_instance=self, state=state, session=session)
-
-    def email_alert(self, exception, task: BaseOperator) -> None:
-        """
-        Send alert email with exception information.
-
-        :param exception: the exception
-        :param task: task related to the exception
-        """
-        from airflow.models.taskinstance import _email_alert
-
-        _email_alert(task_instance=self, exception=exception, task=task)
-
-    def get_email_subject_content(
-        self, exception: BaseException, task: BaseOperator | None = None
-    ) -> tuple[str, str, str]:
-        """
-        Get the email subject content for exceptions.
-
-        :param exception: the exception sent in the email
-        :param task:
-        """
-        from airflow.models.taskinstance import _get_email_subject_content
-
-        return _get_email_subject_content(task_instance=self, exception=exception, task=task)
-
-    def get_previous_ti(
-        self,
-        state: DagRunState | None = None,
-        session: Session | None = None,
-    ) -> TaskInstance | TaskInstancePydantic | None:
-        """
-        Return the task instance for the task that ran before this task instance.
-
-        :param session: SQLAlchemy ORM Session
-        :param state: If passed, it only take into account instances of a specific state.
- """ - from airflow.models.taskinstance import _get_previous_ti - - return _get_previous_ti(task_instance=self, state=state, session=session) - - def check_and_change_state_before_execution( - self, - verbose: bool = True, - ignore_all_deps: bool = False, - ignore_depends_on_past: bool = False, - wait_for_past_depends_before_skipping: bool = False, - ignore_task_deps: bool = False, - ignore_ti_state: bool = False, - mark_success: bool = False, - test_mode: bool = False, - pool: str | None = None, - external_executor_id: str | None = None, - session: Session | None = None, - ) -> bool: - return TaskInstance._check_and_change_state_before_execution( - task_instance=self, - verbose=verbose, - ignore_all_deps=ignore_all_deps, - ignore_depends_on_past=ignore_depends_on_past, - wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping, - ignore_task_deps=ignore_task_deps, - ignore_ti_state=ignore_ti_state, - mark_success=mark_success, - test_mode=test_mode, - hostname=get_hostname(), - pool=pool, - external_executor_id=external_executor_id, - session=session, - ) - - def command_as_list( - self, - mark_success: bool = False, - ignore_all_deps: bool = False, - ignore_task_deps: bool = False, - ignore_depends_on_past: bool = False, - wait_for_past_depends_before_skipping: bool = False, - ignore_ti_state: bool = False, - local: bool = False, - raw: bool = False, - pool: str | None = None, - cfg_path: str | None = None, - ) -> list[str]: - """ - Return a command that can be executed anywhere where airflow is installed. - - This command is part of the message sent to executors by the orchestrator. - """ - return TaskInstance._command_as_list( - ti=self, - mark_success=mark_success, - ignore_all_deps=ignore_all_deps, - ignore_task_deps=ignore_task_deps, - ignore_depends_on_past=ignore_depends_on_past, - wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping, - ignore_ti_state=ignore_ti_state, - local=local, - raw=raw, - pool=pool, - cfg_path=cfg_path, - ) - def _register_asset_changes(self, *, events, session: Session | None = None) -> None: TaskInstance._register_asset_changes(self=self, events=events, session=session) # type: ignore[arg-type] - def defer_task(self, exception: TaskDeferred, session: Session | None = None): - """Defer task.""" - updated_ti = _defer_task(ti=self, exception=exception, session=session) - _set_ti_attrs(self, updated_ti) - def _handle_reschedule( self, actual_start_date: datetime, @@ -525,19 +242,5 @@ def _handle_reschedule( ) _set_ti_attrs(self, updated_ti) # _handle_reschedule is a remote call that mutates the TI - def get_relevant_upstream_map_indexes( - self, - upstream: Operator, - ti_count: int | None, - *, - session: Session | None = None, - ) -> int | range | None: - return TaskInstance.get_relevant_upstream_map_indexes( - self=self, # type: ignore[arg-type] - upstream=upstream, - ti_count=ti_count, - session=session, - ) - TaskInstancePydantic.model_rebuild() diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index d0ea2132f1c952..15511b375bd764 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -42,7 +42,6 @@ from airflow.models.taskinstance import TaskInstance from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent from airflow.serialization.pydantic.asset import AssetEventPydantic from airflow.serialization.pydantic.dag_run import DagRunPydantic -from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic from airflow.typing_compat import 
diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi
index d0ea2132f1c952..15511b375bd764 100644
--- a/airflow/utils/context.pyi
+++ b/airflow/utils/context.pyi
@@ -42,7 +42,6 @@
 from airflow.models.taskinstance import TaskInstance
 from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent
 from airflow.serialization.pydantic.asset import AssetEventPydantic
 from airflow.serialization.pydantic.dag_run import DagRunPydantic
-from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
 from airflow.typing_compat import TypedDict

 KNOWN_CONTEXT_KEYS: set[str]
@@ -118,11 +117,11 @@ class Context(TypedDict, total=False):
     reason: str | None
     run_id: str
     task: BaseOperator
-    task_instance: TaskInstance | TaskInstancePydantic
+    task_instance: TaskInstance
     task_instance_key_str: str
     test_mode: bool
     templates_dict: Mapping[str, Any] | None
-    ti: TaskInstance | TaskInstancePydantic
+    ti: TaskInstance
     triggering_asset_events: Mapping[str, Collection[AssetEvent | AssetEventPydantic]]
     ts: str
     ts_nodash: str
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index e8938dcf16426c..09866de7214ed6 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -46,7 +46,6 @@

     from airflow.models.taskinstance import TaskInstance
     from airflow.models.taskinstancekey import TaskInstanceKey
-    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

 logger = logging.getLogger(__name__)

@@ -137,7 +136,7 @@ def _interleave_logs(*logs):
         last = line


-def _ensure_ti(ti: TaskInstanceKey | TaskInstance | TaskInstancePydantic, session) -> TaskInstance:
+def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance:
     """
     Given TI | TIKey, return a TI object.
diff --git a/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py b/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
index 050cc296382d2c..f1023e91b1653d 100644
--- a/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
+++ b/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
@@ -31,7 +31,6 @@

 if TYPE_CHECKING:
     from airflow.models import TaskInstance
-    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
     from airflow.utils.context import Context


@@ -175,9 +174,7 @@ def _load_data_to_s3(self, data: dict) -> None:
             replace=self.s3_overwrite,
         )

-    def _update_google_api_endpoint_params_via_xcom(
-        self, task_instance: TaskInstance | TaskInstancePydantic
-    ) -> None:
+    def _update_google_api_endpoint_params_via_xcom(self, task_instance: TaskInstance) -> None:
         if self.google_api_endpoint_params_via_xcom:
             google_api_endpoint_params = task_instance.xcom_pull(
                 task_ids=self.google_api_endpoint_params_via_xcom_task_ids,
@@ -185,9 +182,7 @@ def _update_google_api_endpoint_params_via_xcom(
             )
             self.google_api_endpoint_params.update(google_api_endpoint_params)

-    def _expose_google_api_response_via_xcom(
-        self, task_instance: TaskInstance | TaskInstancePydantic, data: dict
-    ) -> None:
+    def _expose_google_api_response_via_xcom(self, task_instance: TaskInstance, data: dict) -> None:
         if sys.getsizeof(data) < MAX_XCOM_SIZE:
             task_instance.xcom_push(key=self.google_api_response_via_xcom or XCOM_RETURN_KEY, value=data)
         else:
diff --git a/providers/src/airflow/providers/standard/operators/bash.py b/providers/src/airflow/providers/standard/operators/bash.py
index 242909f885c227..5b086c8eb676d2 100644
--- a/providers/src/airflow/providers/standard/operators/bash.py
+++ b/providers/src/airflow/providers/standard/operators/bash.py
@@ -249,7 +249,7 @@ def execute(self, context: Context):
         # displays the executed command (otherwise it will display as an ArgNotSet type).
         if self._init_bash_command_not_set:
             is_inline_command = self._is_inline_command(bash_command=cast(str, self.bash_command))
-            ti = cast("TaskInstance", context["ti"])
+            ti = context["ti"]
             self.refresh_bash_command(ti)
         else:
             is_inline_command = self._is_inline_command(bash_command=cast(str, self._unrendered_bash_command))
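Editorial note: because Context.ti is now typed as a plain TaskInstance in context.pyi, the cast removed from BashOperator.execute is representative of a pattern that can disappear wherever operators read the TI from their context. A hedged sketch of the simplified shape:

    # Inside any operator's execute(); context["ti"] already type-checks
    # as TaskInstance, so no cast("TaskInstance", ...) is needed.
    def execute(self, context):
        ti = context["ti"]
        self.refresh_bash_command(ti)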