|
30 | 30 | from airflow.api_fastapi.common.db.common import SessionDep
|
31 | 31 | from airflow.api_fastapi.common.router import AirflowRouter
|
32 | 32 | from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
|
| 33 | + DagRun, |
33 | 34 | TIDeferredStatePayload,
|
34 | 35 | TIEnterRunningPayload,
|
35 | 36 | TIHeartbeatInfo,
|
| 37 | + TIRunContext, |
36 | 38 | TIStateUpdate,
|
37 | 39 | TITerminalStatePayload,
|
38 | 40 | )
|
| 41 | +from airflow.models.dagrun import DagRun as DR |
39 | 42 | from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
|
40 | 43 | from airflow.models.trigger import Trigger
|
41 | 44 | from airflow.utils import timezone
|
|
48 | 51 | log = logging.getLogger(__name__)
|
49 | 52 |
|
50 | 53 |
|
| 54 | +@router.patch( |
| 55 | + "/{task_instance_id}/run", |
| 56 | + status_code=status.HTTP_200_OK, |
| 57 | + responses={ |
| 58 | + status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"}, |
| 59 | + status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"}, |
| 60 | + status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"}, |
| 61 | + }, |
| 62 | +) |
| 63 | +def ti_run( |
| 64 | + task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep |
| 65 | +) -> TIRunContext: |
| 66 | + """ |
| 67 | + Run a TaskInstance. |
| 68 | +
|
| 69 | + This endpoint is used to start a TaskInstance that is in the QUEUED state. |
| 70 | + """ |
| 71 | + # We only use UUID above for validation purposes |
| 72 | + ti_id_str = str(task_instance_id) |
| 73 | + |
| 74 | + old = select(TI.state, TI.dag_id, TI.run_id).where(TI.id == ti_id_str).with_for_update() |
| 75 | + try: |
| 76 | + (previous_state, dag_id, run_id) = session.execute(old).one() |
| 77 | + except NoResultFound: |
| 78 | + log.error("Task Instance %s not found", ti_id_str) |
| 79 | + raise HTTPException( |
| 80 | + status_code=status.HTTP_404_NOT_FOUND, |
| 81 | + detail={ |
| 82 | + "reason": "not_found", |
| 83 | + "message": "Task Instance not found", |
| 84 | + }, |
| 85 | + ) |
| 86 | + |
| 87 | + # We exclude_unset to avoid updating fields that are not set in the payload |
| 88 | + data = ti_run_payload.model_dump(exclude_unset=True) |
| 89 | + |
| 90 | + query = update(TI).where(TI.id == ti_id_str).values(data) |
| 91 | + |
| 92 | + # TODO: We will need to change this for other states like: |
| 93 | + # reschedule, retry, defer etc. |
| 94 | + if previous_state != State.QUEUED: |
| 95 | + log.warning( |
| 96 | + "Can not start Task Instance ('%s') in invalid state: %s", |
| 97 | + ti_id_str, |
| 98 | + previous_state, |
| 99 | + ) |
| 100 | + |
| 101 | + # TODO: Pass a RFC 9457 compliant error message in "detail" field |
| 102 | + # https://datatracker.ietf.org/doc/html/rfc9457 |
| 103 | + # to provide more information about the error |
| 104 | + # FastAPI will automatically convert this to a JSON response |
| 105 | + # This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370 |
| 106 | + raise HTTPException( |
| 107 | + status_code=status.HTTP_409_CONFLICT, |
| 108 | + detail={ |
| 109 | + "reason": "invalid_state", |
| 110 | + "message": "TI was not in a state where it could be marked as running", |
| 111 | + "previous_state": previous_state, |
| 112 | + }, |
| 113 | + ) |
| 114 | + log.info("Task with %s state started on %s ", previous_state, ti_run_payload.hostname) |
| 115 | + # Ensure there is no end date set. |
| 116 | + query = query.values( |
| 117 | + end_date=None, |
| 118 | + hostname=ti_run_payload.hostname, |
| 119 | + unixname=ti_run_payload.unixname, |
| 120 | + pid=ti_run_payload.pid, |
| 121 | + state=State.RUNNING, |
| 122 | + ) |
| 123 | + |
| 124 | + try: |
| 125 | + result = session.execute(query) |
| 126 | + log.info("TI %s state updated: %s row(s) affected", ti_id_str, result.rowcount) |
| 127 | + |
| 128 | + dr = session.execute( |
| 129 | + select( |
| 130 | + DR.run_id, |
| 131 | + DR.dag_id, |
| 132 | + DR.data_interval_start, |
| 133 | + DR.data_interval_end, |
| 134 | + DR.start_date, |
| 135 | + DR.end_date, |
| 136 | + DR.run_type, |
| 137 | + DR.conf, |
| 138 | + DR.logical_date, |
| 139 | + ).filter_by(dag_id=dag_id, run_id=run_id) |
| 140 | + ).one_or_none() |
| 141 | + |
| 142 | + if not dr: |
| 143 | + raise ValueError(f"DagRun with dag_id={dag_id} and run_id={run_id} not found.") |
| 144 | + |
| 145 | + return TIRunContext( |
| 146 | + dag_run=DagRun.model_validate(dr, from_attributes=True), |
| 147 | + # TODO: Add variables and connections that are needed (and has perms) for the task |
| 148 | + variables=[], |
| 149 | + connections=[], |
| 150 | + ) |
| 151 | + except SQLAlchemyError as e: |
| 152 | + log.error("Error marking Task Instance state as running: %s", e) |
| 153 | + raise HTTPException( |
| 154 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred" |
| 155 | + ) |
| 156 | + |
| 157 | + |
51 | 158 | @router.patch(
|
52 | 159 | "/{task_instance_id}/state",
|
53 | 160 | status_code=status.HTTP_204_NO_CONTENT,
|
@@ -92,37 +199,7 @@ def ti_update_state(
|
92 | 199 |
|
93 | 200 | query = update(TI).where(TI.id == ti_id_str).values(data)
|
94 | 201 |
|
95 |
| - if isinstance(ti_patch_payload, TIEnterRunningPayload): |
96 |
| - if previous_state != State.QUEUED: |
97 |
| - log.warning( |
98 |
| - "Can not start Task Instance ('%s') in invalid state: %s", |
99 |
| - ti_id_str, |
100 |
| - previous_state, |
101 |
| - ) |
102 |
| - |
103 |
| - # TODO: Pass a RFC 9457 compliant error message in "detail" field |
104 |
| - # https://datatracker.ietf.org/doc/html/rfc9457 |
105 |
| - # to provide more information about the error |
106 |
| - # FastAPI will automatically convert this to a JSON response |
107 |
| - # This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370 |
108 |
| - raise HTTPException( |
109 |
| - status_code=status.HTTP_409_CONFLICT, |
110 |
| - detail={ |
111 |
| - "reason": "invalid_state", |
112 |
| - "message": "TI was not in a state where it could be marked as running", |
113 |
| - "previous_state": previous_state, |
114 |
| - }, |
115 |
| - ) |
116 |
| - log.info("Task with %s state started on %s ", previous_state, ti_patch_payload.hostname) |
117 |
| - # Ensure there is no end date set. |
118 |
| - query = query.values( |
119 |
| - end_date=None, |
120 |
| - hostname=ti_patch_payload.hostname, |
121 |
| - unixname=ti_patch_payload.unixname, |
122 |
| - pid=ti_patch_payload.pid, |
123 |
| - state=State.RUNNING, |
124 |
| - ) |
125 |
| - elif isinstance(ti_patch_payload, TITerminalStatePayload): |
| 202 | + if isinstance(ti_patch_payload, TITerminalStatePayload): |
126 | 203 | query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
|
127 | 204 | elif isinstance(ti_patch_payload, TIDeferredStatePayload):
|
128 | 205 | # Calculate timeout if it was passed
|
|
0 commit comments