Skip to content

Commit 2170af7

Browse files
committed
AIP-72: Pass context keys from API Server to Workers
Part of #44481
1 parent aaf29ee commit 2170af7

File tree

11 files changed

+382
-72
lines changed

11 files changed

+382
-72
lines changed

airflow/api_fastapi/execution_api/datamodels/taskinstance.py

+34
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525

2626
from airflow.api_fastapi.common.types import UtcDateTime
2727
from airflow.api_fastapi.core_api.base import BaseModel
28+
from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse
29+
from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse
2830
from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState
31+
from airflow.utils.types import DagRunType
2932

3033

3134
class TIEnterRunningPayload(BaseModel):
@@ -135,3 +138,34 @@ class TaskInstance(BaseModel):
135138
run_id: str
136139
try_number: int
137140
map_index: int | None = None
141+
142+
143+
class DagRun(BaseModel):
144+
"""Schema for DagRun model with minimal required fields needed for Runtime."""
145+
146+
# TODO: `dag_id` and `run_id` are duplicated from TaskInstance
147+
# See if we can avoid sending these fields from the API server and instead
148+
# use the TaskInstance data to get the DAG run information in the client (Task Execution Interface).
149+
dag_id: str
150+
run_id: str
151+
152+
logical_date: UtcDateTime
153+
data_interval_start: UtcDateTime | None
154+
data_interval_end: UtcDateTime | None
155+
start_date: UtcDateTime
156+
end_date: UtcDateTime | None
157+
run_type: DagRunType
158+
conf: Annotated[dict[str, Any], Field(default_factory=dict)]
159+
160+
161+
class TIRunContext(BaseModel):
162+
"""Response schema for TaskInstance run context."""
163+
164+
dag_run: DagRun
165+
"""DAG run information for the task instance."""
166+
167+
variables: Annotated[list[VariableResponse], Field(default_factory=list)]
168+
"""Variables that can be accessed by the task instance."""
169+
170+
connections: Annotated[list[ConnectionResponse], Field(default_factory=list)]
171+
"""Connections that can be accessed by the task instance."""

airflow/api_fastapi/execution_api/routes/task_instances.py

+113-28
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@
3030
from airflow.api_fastapi.common.db.common import SessionDep
3131
from airflow.api_fastapi.common.router import AirflowRouter
3232
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
33+
DagRun,
3334
TIDeferredStatePayload,
3435
TIEnterRunningPayload,
3536
TIHeartbeatInfo,
37+
TIRunContext,
3638
TIStateUpdate,
3739
TITerminalStatePayload,
3840
)
41+
from airflow.models.dagrun import DagRun as DR
3942
from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
4043
from airflow.models.trigger import Trigger
4144
from airflow.utils import timezone
@@ -48,6 +51,108 @@
4851
log = logging.getLogger(__name__)
4952

5053

54+
@router.patch(
55+
"/{task_instance_id}/run",
56+
status_code=status.HTTP_200_OK,
57+
responses={
58+
status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
59+
status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"},
60+
status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"},
61+
},
62+
)
63+
def ti_run(
64+
task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep
65+
) -> TIRunContext:
66+
"""
67+
Run a TaskInstance.
68+
69+
This endpoint is used to start a TaskInstance that is in the QUEUED state.
70+
"""
71+
# We only use UUID above for validation purposes
72+
ti_id_str = str(task_instance_id)
73+
74+
old = select(TI.state, TI.dag_id, TI.run_id).where(TI.id == ti_id_str).with_for_update()
75+
try:
76+
(previous_state, dag_id, run_id) = session.execute(old).one()
77+
except NoResultFound:
78+
log.error("Task Instance %s not found", ti_id_str)
79+
raise HTTPException(
80+
status_code=status.HTTP_404_NOT_FOUND,
81+
detail={
82+
"reason": "not_found",
83+
"message": "Task Instance not found",
84+
},
85+
)
86+
87+
# We exclude_unset to avoid updating fields that are not set in the payload
88+
data = ti_run_payload.model_dump(exclude_unset=True)
89+
90+
query = update(TI).where(TI.id == ti_id_str).values(data)
91+
92+
if previous_state != State.QUEUED:
93+
log.warning(
94+
"Can not start Task Instance ('%s') in invalid state: %s",
95+
ti_id_str,
96+
previous_state,
97+
)
98+
99+
# TODO: Pass a RFC 9457 compliant error message in "detail" field
100+
# https://datatracker.ietf.org/doc/html/rfc9457
101+
# to provide more information about the error
102+
# FastAPI will automatically convert this to a JSON response
103+
# This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370
104+
raise HTTPException(
105+
status_code=status.HTTP_409_CONFLICT,
106+
detail={
107+
"reason": "invalid_state",
108+
"message": "TI was not in a state where it could be marked as running",
109+
"previous_state": previous_state,
110+
},
111+
)
112+
log.info("Task with %s state started on %s ", previous_state, ti_run_payload.hostname)
113+
# Ensure there is no end date set.
114+
query = query.values(
115+
end_date=None,
116+
hostname=ti_run_payload.hostname,
117+
unixname=ti_run_payload.unixname,
118+
pid=ti_run_payload.pid,
119+
state=State.RUNNING,
120+
)
121+
122+
try:
123+
result = session.execute(query)
124+
log.info("TI %s state updated: %s row(s) affected", ti_id_str, result.rowcount)
125+
126+
dr = session.execute(
127+
select(
128+
DR.run_id,
129+
DR.dag_id,
130+
DR.data_interval_start,
131+
DR.data_interval_end,
132+
DR.start_date,
133+
DR.end_date,
134+
DR.run_type,
135+
DR.conf,
136+
DR.logical_date,
137+
).filter_by(dag_id=dag_id, run_id=run_id)
138+
).one_or_none()
139+
140+
if not dr:
141+
raise ValueError(f"DagRun with dag_id={dag_id} and run_id={run_id} not found.")
142+
143+
return TIRunContext(
144+
dag_run=DagRun.model_validate(dr, from_attributes=True),
145+
# TODO: Add the variables and connections that the task needs (and has permission to access)
146+
variables=[],
147+
connections=[],
148+
)
149+
except SQLAlchemyError as e:
150+
log.error("Error marking Task Instance state as running: %s", e)
151+
raise HTTPException(
152+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred"
153+
)
154+
155+
51156
@router.patch(
52157
"/{task_instance_id}/state",
53158
status_code=status.HTTP_204_NO_CONTENT,
@@ -92,35 +197,15 @@ def ti_update_state(
92197

93198
query = update(TI).where(TI.id == ti_id_str).values(data)
94199

200+
# TODO: Instead, remove this payload type from the discriminator accepted by this endpoint
95201
if isinstance(ti_patch_payload, TIEnterRunningPayload):
96-
if previous_state != State.QUEUED:
97-
log.warning(
98-
"Can not start Task Instance ('%s') in invalid state: %s",
99-
ti_id_str,
100-
previous_state,
101-
)
102-
103-
# TODO: Pass a RFC 9457 compliant error message in "detail" field
104-
# https://datatracker.ietf.org/doc/html/rfc9457
105-
# to provide more information about the error
106-
# FastAPI will automatically convert this to a JSON response
107-
# This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370
108-
raise HTTPException(
109-
status_code=status.HTTP_409_CONFLICT,
110-
detail={
111-
"reason": "invalid_state",
112-
"message": "TI was not in a state where it could be marked as running",
113-
"previous_state": previous_state,
114-
},
115-
)
116-
log.info("Task with %s state started on %s ", previous_state, ti_patch_payload.hostname)
117-
# Ensure there is no end date set.
118-
query = query.values(
119-
end_date=None,
120-
hostname=ti_patch_payload.hostname,
121-
unixname=ti_patch_payload.unixname,
122-
pid=ti_patch_payload.pid,
123-
state=State.RUNNING,
202+
raise HTTPException(
203+
status_code=status.HTTP_409_CONFLICT,
204+
detail={
205+
"reason": "invalid_state",
206+
"message": "TI should be started using the /run endpoint",
207+
"previous_state": previous_state,
208+
},
124209
)
125210
elif isinstance(ti_patch_payload, TITerminalStatePayload):
126211
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)

task_sdk/src/airflow/sdk/api/client.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@
3030
from airflow.sdk import __version__
3131
from airflow.sdk.api.datamodels._generated import (
3232
ConnectionResponse,
33+
DagRunType,
3334
TerminalTIState,
3435
TIDeferredStatePayload,
3536
TIEnterRunningPayload,
3637
TIHeartbeatInfo,
38+
TIRunContext,
3739
TITerminalStatePayload,
3840
ValidationError as RemoteValidationError,
3941
VariablePostBody,
@@ -110,11 +112,12 @@ class TaskInstanceOperations:
110112
def __init__(self, client: Client):
111113
self.client = client
112114

113-
def start(self, id: uuid.UUID, pid: int, when: datetime):
115+
def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext:
114116
"""Tell the API server that this TI has started running."""
115117
body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), unixname=getuser(), start_date=when)
116118

117-
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())
119+
resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json())
120+
return TIRunContext.model_validate_json(resp.read())
118121

119122
def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
120123
"""Tell the API server that this TI has reached a terminal state."""
@@ -218,7 +221,23 @@ def auth_flow(self, request: httpx.Request):
218221
# This exists as an aid for debugging or local running via the `dry_run` argument to Client. It doesn't make
219222
# sense for returning connections etc.
220223
def noop_handler(request: httpx.Request) -> httpx.Response:
221-
log.debug("Dry-run request", method=request.method, path=request.url.path)
224+
path = request.url.path
225+
log.debug("Dry-run request", method=request.method, path=path)
226+
227+
if path.startswith("/task-instances/") and path.endswith("/run"):
228+
# Return a fake context
229+
return httpx.Response(
230+
200,
231+
json={
232+
"dag_run": {
233+
"dag_id": "test_dag",
234+
"run_id": "test_run",
235+
"logical_date": "2021-01-01T00:00:00Z",
236+
"start_date": "2021-01-01T00:00:00Z",
237+
"run_type": DagRunType.MANUAL,
238+
},
239+
},
240+
)
222241
return httpx.Response(200, json={"text": "Hello, world!"})
223242

224243

task_sdk/src/airflow/sdk/api/datamodels/_generated.py

+37
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,17 @@ class ConnectionResponse(BaseModel):
4444
extra: Annotated[str | None, Field(title="Extra")] = None
4545

4646

47+
class DagRunType(str, Enum):
48+
"""
49+
Class with DagRun types.
50+
"""
51+
52+
BACKFILL = "backfill"
53+
SCHEDULED = "scheduled"
54+
MANUAL = "manual"
55+
ASSET_TRIGGERED = "asset_triggered"
56+
57+
4758
class IntermediateTIState(str, Enum):
4859
"""
4960
States that a Task Instance can be in that indicate it is not yet in a terminal or running state.
@@ -159,10 +170,36 @@ class TaskInstance(BaseModel):
159170
map_index: Annotated[int | None, Field(title="Map Index")] = None
160171

161172

173+
class DagRun(BaseModel):
174+
"""
175+
Schema for DagRun model with minimal required fields needed for Runtime.
176+
"""
177+
178+
dag_id: Annotated[str, Field(title="Dag Id")]
179+
run_id: Annotated[str, Field(title="Run Id")]
180+
logical_date: Annotated[datetime, Field(title="Logical Date")]
181+
data_interval_start: Annotated[datetime | None, Field(title="Data Interval Start")] = None
182+
data_interval_end: Annotated[datetime | None, Field(title="Data Interval End")] = None
183+
start_date: Annotated[datetime, Field(title="Start Date")]
184+
end_date: Annotated[datetime | None, Field(title="End Date")] = None
185+
run_type: DagRunType
186+
conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None
187+
188+
162189
class HTTPValidationError(BaseModel):
163190
detail: Annotated[list[ValidationError] | None, Field(title="Detail")] = None
164191

165192

193+
class TIRunContext(BaseModel):
194+
"""
195+
Response schema for TaskInstance run context.
196+
"""
197+
198+
dag_run: DagRun
199+
variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None
200+
connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None
201+
202+
166203
class TITerminalStatePayload(BaseModel):
167204
"""
168205
Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED).

task_sdk/src/airflow/sdk/execution_time/comms.py

+2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
TaskInstance,
5555
TerminalTIState,
5656
TIDeferredStatePayload,
57+
TIRunContext,
5758
VariableResponse,
5859
XComResponse,
5960
)
@@ -70,6 +71,7 @@ class StartupDetails(BaseModel):
7071
7172
Responses will come back on stdin
7273
"""
74+
ti_context: TIRunContext | None = None
7375
type: Literal["StartupDetails"] = "StartupDetails"
7476

7577

task_sdk/src/airflow/sdk/execution_time/supervisor.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str], requ
397397
# We've forked, but the task won't start doing anything until we send it the StartupDetails
398398
# message. But before we do that, we need to tell the server it's started (so it has the chance to
399399
# tell us "no, stop!" for any reason)
400-
self.client.task_instances.start(ti.id, self.pid, datetime.now(tz=timezone.utc))
400+
ti_context = self.client.task_instances.start(ti.id, self.pid, datetime.now(tz=timezone.utc))
401401
self._last_successful_heartbeat = time.monotonic()
402402
except Exception:
403403
# On any error kill that subprocess!
@@ -408,6 +408,7 @@ def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str], requ
408408
ti=ti,
409409
file=os.fspath(path),
410410
requests_fd=requests_fd,
411+
ti_context=ti_context,
411412
)
412413

413414
# Send the message to tell the process what it needs to execute

0 commit comments

Comments
 (0)