Skip to content

Commit dbff6e3

Browse files
authored
AIP-72: Pass context keys from API Server to Workers (#44899)
Part of #44481 This commit augments the TI context available in the Task Execution Interface with the one from the Execution API Server. In future PRs the following will be added: - More methods on TI like ti.xcom_pull, ti.xcom_push etc - Lazy fetching of connections, variables - Verifying the "get_current_context" is working
1 parent 4b38bed commit dbff6e3

File tree

12 files changed

+506
-91
lines changed

12 files changed

+506
-91
lines changed

airflow/api_fastapi/execution_api/datamodels/taskinstance.py

+35-4
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525

2626
from airflow.api_fastapi.common.types import UtcDateTime
2727
from airflow.api_fastapi.core_api.base import BaseModel
28+
from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse
29+
from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse
2830
from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState
31+
from airflow.utils.types import DagRunType
2932

3033

3134
class TIEnterRunningPayload(BaseModel):
@@ -94,9 +97,7 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
9497
state = v.get("state")
9598
else:
9699
state = getattr(v, "state", None)
97-
if state == TIState.RUNNING:
98-
return str(state)
99-
elif state in set(TerminalTIState):
100+
if state in set(TerminalTIState):
100101
return "_terminal_"
101102
elif state == TIState.DEFERRED:
102103
return "deferred"
@@ -107,7 +108,6 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
107108
# and "_other_" is a catch-all for all other states that are not covered by the other schemas.
108109
TIStateUpdate = Annotated[
109110
Union[
110-
Annotated[TIEnterRunningPayload, Tag("running")],
111111
Annotated[TITerminalStatePayload, Tag("_terminal_")],
112112
Annotated[TITargetStatePayload, Tag("_other_")],
113113
Annotated[TIDeferredStatePayload, Tag("deferred")],
@@ -135,3 +135,34 @@ class TaskInstance(BaseModel):
135135
run_id: str
136136
try_number: int
137137
map_index: int | None = None
138+
139+
140+
class DagRun(BaseModel):
141+
"""Schema for DagRun model with minimal required fields needed for Runtime."""
142+
143+
# TODO: `dag_id` and `run_id` are duplicated from TaskInstance
144+
# See if we can avoid sending these fields from API server and instead
145+
# use the TaskInstance data to get the DAG run information in the client (Task Execution Interface).
146+
dag_id: str
147+
run_id: str
148+
149+
logical_date: UtcDateTime
150+
data_interval_start: UtcDateTime | None
151+
data_interval_end: UtcDateTime | None
152+
start_date: UtcDateTime
153+
end_date: UtcDateTime | None
154+
run_type: DagRunType
155+
conf: Annotated[dict[str, Any], Field(default_factory=dict)]
156+
157+
158+
class TIRunContext(BaseModel):
159+
"""Response schema for TaskInstance run context."""
160+
161+
dag_run: DagRun
162+
"""DAG run information for the task instance."""
163+
164+
variables: Annotated[list[VariableResponse], Field(default_factory=list)]
165+
"""Variables that can be accessed by the task instance."""
166+
167+
connections: Annotated[list[ConnectionResponse], Field(default_factory=list)]
168+
"""Connections that can be accessed by the task instance."""

airflow/api_fastapi/execution_api/routes/task_instances.py

+108-31
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@
3030
from airflow.api_fastapi.common.db.common import SessionDep
3131
from airflow.api_fastapi.common.router import AirflowRouter
3232
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
33+
DagRun,
3334
TIDeferredStatePayload,
3435
TIEnterRunningPayload,
3536
TIHeartbeatInfo,
37+
TIRunContext,
3638
TIStateUpdate,
3739
TITerminalStatePayload,
3840
)
41+
from airflow.models.dagrun import DagRun as DR
3942
from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
4043
from airflow.models.trigger import Trigger
4144
from airflow.utils import timezone
@@ -48,6 +51,110 @@
4851
log = logging.getLogger(__name__)
4952

5053

54+
@router.patch(
55+
"/{task_instance_id}/run",
56+
status_code=status.HTTP_200_OK,
57+
responses={
58+
status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
59+
status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"},
60+
status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"},
61+
},
62+
)
63+
def ti_run(
64+
task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep
65+
) -> TIRunContext:
66+
"""
67+
Run a TaskInstance.
68+
69+
This endpoint is used to start a TaskInstance that is in the QUEUED state.
70+
"""
71+
# We only use UUID above for validation purposes
72+
ti_id_str = str(task_instance_id)
73+
74+
old = select(TI.state, TI.dag_id, TI.run_id).where(TI.id == ti_id_str).with_for_update()
75+
try:
76+
(previous_state, dag_id, run_id) = session.execute(old).one()
77+
except NoResultFound:
78+
log.error("Task Instance %s not found", ti_id_str)
79+
raise HTTPException(
80+
status_code=status.HTTP_404_NOT_FOUND,
81+
detail={
82+
"reason": "not_found",
83+
"message": "Task Instance not found",
84+
},
85+
)
86+
87+
# We exclude_unset to avoid updating fields that are not set in the payload
88+
data = ti_run_payload.model_dump(exclude_unset=True)
89+
90+
query = update(TI).where(TI.id == ti_id_str).values(data)
91+
92+
# TODO: We will need to change this for other states like:
93+
# reschedule, retry, defer etc.
94+
if previous_state != State.QUEUED:
95+
log.warning(
96+
"Can not start Task Instance ('%s') in invalid state: %s",
97+
ti_id_str,
98+
previous_state,
99+
)
100+
101+
# TODO: Pass a RFC 9457 compliant error message in "detail" field
102+
# https://datatracker.ietf.org/doc/html/rfc9457
103+
# to provide more information about the error
104+
# FastAPI will automatically convert this to a JSON response
105+
# This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370
106+
raise HTTPException(
107+
status_code=status.HTTP_409_CONFLICT,
108+
detail={
109+
"reason": "invalid_state",
110+
"message": "TI was not in a state where it could be marked as running",
111+
"previous_state": previous_state,
112+
},
113+
)
114+
log.info("Task with %s state started on %s ", previous_state, ti_run_payload.hostname)
115+
# Ensure there is no end date set.
116+
query = query.values(
117+
end_date=None,
118+
hostname=ti_run_payload.hostname,
119+
unixname=ti_run_payload.unixname,
120+
pid=ti_run_payload.pid,
121+
state=State.RUNNING,
122+
)
123+
124+
try:
125+
result = session.execute(query)
126+
log.info("TI %s state updated: %s row(s) affected", ti_id_str, result.rowcount)
127+
128+
dr = session.execute(
129+
select(
130+
DR.run_id,
131+
DR.dag_id,
132+
DR.data_interval_start,
133+
DR.data_interval_end,
134+
DR.start_date,
135+
DR.end_date,
136+
DR.run_type,
137+
DR.conf,
138+
DR.logical_date,
139+
).filter_by(dag_id=dag_id, run_id=run_id)
140+
).one_or_none()
141+
142+
if not dr:
143+
raise ValueError(f"DagRun with dag_id={dag_id} and run_id={run_id} not found.")
144+
145+
return TIRunContext(
146+
dag_run=DagRun.model_validate(dr, from_attributes=True),
147+
# TODO: Add variables and connections that are needed (and has perms) for the task
148+
variables=[],
149+
connections=[],
150+
)
151+
except SQLAlchemyError as e:
152+
log.error("Error marking Task Instance state as running: %s", e)
153+
raise HTTPException(
154+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred"
155+
)
156+
157+
51158
@router.patch(
52159
"/{task_instance_id}/state",
53160
status_code=status.HTTP_204_NO_CONTENT,
@@ -92,37 +199,7 @@ def ti_update_state(
92199

93200
query = update(TI).where(TI.id == ti_id_str).values(data)
94201

95-
if isinstance(ti_patch_payload, TIEnterRunningPayload):
96-
if previous_state != State.QUEUED:
97-
log.warning(
98-
"Can not start Task Instance ('%s') in invalid state: %s",
99-
ti_id_str,
100-
previous_state,
101-
)
102-
103-
# TODO: Pass a RFC 9457 compliant error message in "detail" field
104-
# https://datatracker.ietf.org/doc/html/rfc9457
105-
# to provide more information about the error
106-
# FastAPI will automatically convert this to a JSON response
107-
# This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370
108-
raise HTTPException(
109-
status_code=status.HTTP_409_CONFLICT,
110-
detail={
111-
"reason": "invalid_state",
112-
"message": "TI was not in a state where it could be marked as running",
113-
"previous_state": previous_state,
114-
},
115-
)
116-
log.info("Task with %s state started on %s ", previous_state, ti_patch_payload.hostname)
117-
# Ensure there is no end date set.
118-
query = query.values(
119-
end_date=None,
120-
hostname=ti_patch_payload.hostname,
121-
unixname=ti_patch_payload.unixname,
122-
pid=ti_patch_payload.pid,
123-
state=State.RUNNING,
124-
)
125-
elif isinstance(ti_patch_payload, TITerminalStatePayload):
202+
if isinstance(ti_patch_payload, TITerminalStatePayload):
126203
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
127204
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
128205
# Calculate timeout if it was passed

task_sdk/src/airflow/sdk/api/client.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@
3030
from airflow.sdk import __version__
3131
from airflow.sdk.api.datamodels._generated import (
3232
ConnectionResponse,
33+
DagRunType,
3334
TerminalTIState,
3435
TIDeferredStatePayload,
3536
TIEnterRunningPayload,
3637
TIHeartbeatInfo,
38+
TIRunContext,
3739
TITerminalStatePayload,
3840
ValidationError as RemoteValidationError,
3941
VariablePostBody,
@@ -110,11 +112,12 @@ class TaskInstanceOperations:
110112
def __init__(self, client: Client):
111113
self.client = client
112114

113-
def start(self, id: uuid.UUID, pid: int, when: datetime):
115+
def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext:
114116
"""Tell the API server that this TI has started running."""
115117
body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), unixname=getuser(), start_date=when)
116118

117-
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())
119+
resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json())
120+
return TIRunContext.model_validate_json(resp.read())
118121

119122
def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
120123
"""Tell the API server that this TI has reached a terminal state."""
@@ -218,7 +221,23 @@ def auth_flow(self, request: httpx.Request):
218221
# This exists as a aid for debugging or local running via the `dry_run` argument to Client. It doesn't make
219222
# sense for returning connections etc.
220223
def noop_handler(request: httpx.Request) -> httpx.Response:
221-
log.debug("Dry-run request", method=request.method, path=request.url.path)
224+
path = request.url.path
225+
log.debug("Dry-run request", method=request.method, path=path)
226+
227+
if path.startswith("/task-instances/") and path.endswith("/run"):
228+
# Return a fake context
229+
return httpx.Response(
230+
200,
231+
json={
232+
"dag_run": {
233+
"dag_id": "test_dag",
234+
"run_id": "test_run",
235+
"logical_date": "2021-01-01T00:00:00Z",
236+
"start_date": "2021-01-01T00:00:00Z",
237+
"run_type": DagRunType.MANUAL,
238+
},
239+
},
240+
)
222241
return httpx.Response(200, json={"text": "Hello, world!"})
223242

224243

task_sdk/src/airflow/sdk/api/datamodels/_generated.py

+37
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,17 @@ class ConnectionResponse(BaseModel):
4444
extra: Annotated[str | None, Field(title="Extra")] = None
4545

4646

47+
class DagRunType(str, Enum):
48+
"""
49+
Class with DagRun types.
50+
"""
51+
52+
BACKFILL = "backfill"
53+
SCHEDULED = "scheduled"
54+
MANUAL = "manual"
55+
ASSET_TRIGGERED = "asset_triggered"
56+
57+
4758
class IntermediateTIState(str, Enum):
4859
"""
4960
States that a Task Instance can be in that indicate it is not yet in a terminal or running state.
@@ -159,10 +170,36 @@ class TaskInstance(BaseModel):
159170
map_index: Annotated[int | None, Field(title="Map Index")] = None
160171

161172

173+
class DagRun(BaseModel):
174+
"""
175+
Schema for DagRun model with minimal required fields needed for Runtime.
176+
"""
177+
178+
dag_id: Annotated[str, Field(title="Dag Id")]
179+
run_id: Annotated[str, Field(title="Run Id")]
180+
logical_date: Annotated[datetime, Field(title="Logical Date")]
181+
data_interval_start: Annotated[datetime | None, Field(title="Data Interval Start")] = None
182+
data_interval_end: Annotated[datetime | None, Field(title="Data Interval End")] = None
183+
start_date: Annotated[datetime, Field(title="Start Date")]
184+
end_date: Annotated[datetime | None, Field(title="End Date")] = None
185+
run_type: DagRunType
186+
conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None
187+
188+
162189
class HTTPValidationError(BaseModel):
163190
detail: Annotated[list[ValidationError] | None, Field(title="Detail")] = None
164191

165192

193+
class TIRunContext(BaseModel):
194+
"""
195+
Response schema for TaskInstance run context.
196+
"""
197+
198+
dag_run: DagRun
199+
variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None
200+
connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None
201+
202+
166203
class TITerminalStatePayload(BaseModel):
167204
"""
168205
Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED).

task_sdk/src/airflow/sdk/execution_time/comms.py

+2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
TaskInstance,
5555
TerminalTIState,
5656
TIDeferredStatePayload,
57+
TIRunContext,
5758
VariableResponse,
5859
XComResponse,
5960
)
@@ -70,6 +71,7 @@ class StartupDetails(BaseModel):
7071
7172
Responses will come back on stdin
7273
"""
74+
ti_context: TIRunContext
7375
type: Literal["StartupDetails"] = "StartupDetails"
7476

7577

task_sdk/src/airflow/sdk/execution_time/supervisor.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str], requ
397397
# We've forked, but the task won't start doing anything until we send it the StartupDetails
398398
# message. But before we do that, we need to tell the server it's started (so it has the chance to
399399
# tell us "no, stop!" for any reason)
400-
self.client.task_instances.start(ti.id, self.pid, datetime.now(tz=timezone.utc))
400+
ti_context = self.client.task_instances.start(ti.id, self.pid, datetime.now(tz=timezone.utc))
401401
self._last_successful_heartbeat = time.monotonic()
402402
except Exception:
403403
# On any error kill that subprocess!
@@ -408,6 +408,7 @@ def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str], requ
408408
ti=ti,
409409
file=os.fspath(path),
410410
requests_fd=requests_fd,
411+
ti_context=ti_context,
411412
)
412413

413414
# Send the message to tell the process what it needs to execute

0 commit comments

Comments
 (0)