Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Pass context keys from API Server to Workers #44899

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@

from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse
from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse
from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState
from airflow.utils.types import DagRunType


class TIEnterRunningPayload(BaseModel):
Expand Down Expand Up @@ -94,9 +97,7 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
state = v.get("state")
else:
state = getattr(v, "state", None)
if state == TIState.RUNNING:
return str(state)
elif state in set(TerminalTIState):
if state in set(TerminalTIState):
return "_terminal_"
elif state == TIState.DEFERRED:
return "deferred"
Expand All @@ -107,7 +108,6 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
# and "_other_" is a catch-all for all other states that are not covered by the other schemas.
TIStateUpdate = Annotated[
Union[
Annotated[TIEnterRunningPayload, Tag("running")],
Annotated[TITerminalStatePayload, Tag("_terminal_")],
Annotated[TITargetStatePayload, Tag("_other_")],
Annotated[TIDeferredStatePayload, Tag("deferred")],
Expand Down Expand Up @@ -135,3 +135,34 @@ class TaskInstance(BaseModel):
run_id: str
try_number: int
map_index: int | None = None


class DagRun(BaseModel):
    """Schema for DagRun model with minimal required fields needed for Runtime."""

    # TODO: `dag_id` and `run_id` are duplicated from TaskInstance
    # See if we can avoid sending these fields from API server and instead
    # use the TaskInstance data to get the DAG run information in the client (Task Execution Interface).
    dag_id: str
    run_id: str

    # These fields mirror the DagRun columns the execution API selects when a TI starts;
    # the Optional ones may legitimately be absent on the row.
    logical_date: UtcDateTime
    data_interval_start: UtcDateTime | None
    data_interval_end: UtcDateTime | None
    start_date: UtcDateTime
    end_date: UtcDateTime | None
    run_type: DagRunType
    # Defaults to an empty dict (not None) so consumers can index into it safely.
    conf: Annotated[dict[str, Any], Field(default_factory=dict)]


class TIRunContext(BaseModel):
    """
    Response schema for TaskInstance run context.

    Returned by the execution API when a worker reports that a TI has started
    running, bundling up-front everything the task runtime needs.
    """

    dag_run: DagRun
    """DAG run information for the task instance."""

    variables: Annotated[list[VariableResponse], Field(default_factory=list)]
    """Variables that can be accessed by the task instance."""

    connections: Annotated[list[ConnectionResponse], Field(default_factory=list)]
    """Connections that can be accessed by the task instance."""
139 changes: 108 additions & 31 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
DagRun,
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
TIRunContext,
TIStateUpdate,
TITerminalStatePayload,
)
from airflow.models.dagrun import DagRun as DR
from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
from airflow.models.trigger import Trigger
from airflow.utils import timezone
Expand All @@ -48,6 +51,110 @@
log = logging.getLogger(__name__)


@router.patch(
    "/{task_instance_id}/run",
    status_code=status.HTTP_200_OK,
    responses={
        status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
        status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"},
        status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"},
    },
)
def ti_run(
    task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep
) -> TIRunContext:
    """
    Run a TaskInstance.

    This endpoint is used to start a TaskInstance that is in the QUEUED state.

    :param task_instance_id: id of the TI to transition to RUNNING.
    :param ti_run_payload: hostname/unixname/pid/start info reported by the worker.
    :param session: database session (dependency injected).
    :return: the run context (dag run info, variables, connections) the worker needs.
    :raises HTTPException: 404 if the TI does not exist, 409 if it is not QUEUED,
        500 on database errors during the update.
    """
    # We only use UUID above for validation purposes
    ti_id_str = str(task_instance_id)

    # Read the current state under a row lock (FOR UPDATE) so a concurrent request
    # cannot transition the same TI between this check and the update below.
    old = select(TI.state, TI.dag_id, TI.run_id).where(TI.id == ti_id_str).with_for_update()
    try:
        (previous_state, dag_id, run_id) = session.execute(old).one()
    except NoResultFound:
        log.error("Task Instance %s not found", ti_id_str)
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={
                "reason": "not_found",
                "message": "Task Instance not found",
            },
        )

    # We exclude_unset to avoid updating fields that are not set in the payload
    data = ti_run_payload.model_dump(exclude_unset=True)

    query = update(TI).where(TI.id == ti_id_str).values(data)

    # TODO: We will need to change this for other states like:
    # reschedule, retry, defer etc.
    if previous_state != State.QUEUED:
        log.warning(
            "Can not start Task Instance ('%s') in invalid state: %s",
            ti_id_str,
            previous_state,
        )

        # TODO: Pass a RFC 9457 compliant error message in "detail" field
        # https://datatracker.ietf.org/doc/html/rfc9457
        # to provide more information about the error
        # FastAPI will automatically convert this to a JSON response
        # This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail={
                "reason": "invalid_state",
                "message": "TI was not in a state where it could be marked as running",
                "previous_state": previous_state,
            },
        )
    log.info("Task with %s state started on %s ", previous_state, ti_run_payload.hostname)
    # Ensure there is no end date set.
    # NOTE(review): hostname/unixname/pid may already be present in `data` above;
    # re-applying them here is redundant but harmless — the last .values() wins.
    query = query.values(
        end_date=None,
        hostname=ti_run_payload.hostname,
        unixname=ti_run_payload.unixname,
        pid=ti_run_payload.pid,
        state=State.RUNNING,
    )

    try:
        result = session.execute(query)
        log.info("TI %s state updated: %s row(s) affected", ti_id_str, result.rowcount)

        # Fetch the owning DagRun so the worker receives its run context in the response.
        dr = session.execute(
            select(
                DR.run_id,
                DR.dag_id,
                DR.data_interval_start,
                DR.data_interval_end,
                DR.start_date,
                DR.end_date,
                DR.run_type,
                DR.conf,
                DR.logical_date,
            ).filter_by(dag_id=dag_id, run_id=run_id)
        ).one_or_none()

        if not dr:
            # NOTE(review): ValueError is not a SQLAlchemyError, so it escapes the
            # handler below and surfaces as an unhandled 500 — confirm this is intended.
            raise ValueError(f"DagRun with dag_id={dag_id} and run_id={run_id} not found.")

        return TIRunContext(
            dag_run=DagRun.model_validate(dr, from_attributes=True),
            # TODO: Add variables and connections that are needed (and has perms) for the task
            variables=[],
            connections=[],
        )
    except SQLAlchemyError as e:
        log.error("Error marking Task Instance state as running: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Database error occurred"
        )


@router.patch(
"/{task_instance_id}/state",
status_code=status.HTTP_204_NO_CONTENT,
Expand Down Expand Up @@ -92,37 +199,7 @@ def ti_update_state(

query = update(TI).where(TI.id == ti_id_str).values(data)

if isinstance(ti_patch_payload, TIEnterRunningPayload):
if previous_state != State.QUEUED:
log.warning(
"Can not start Task Instance ('%s') in invalid state: %s",
ti_id_str,
previous_state,
)

# TODO: Pass a RFC 9457 compliant error message in "detail" field
# https://datatracker.ietf.org/doc/html/rfc9457
# to provide more information about the error
# FastAPI will automatically convert this to a JSON response
# This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail={
"reason": "invalid_state",
"message": "TI was not in a state where it could be marked as running",
"previous_state": previous_state,
},
)
log.info("Task with %s state started on %s ", previous_state, ti_patch_payload.hostname)
# Ensure there is no end date set.
query = query.values(
end_date=None,
hostname=ti_patch_payload.hostname,
unixname=ti_patch_payload.unixname,
pid=ti_patch_payload.pid,
state=State.RUNNING,
)
elif isinstance(ti_patch_payload, TITerminalStatePayload):
if isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
Expand Down
25 changes: 22 additions & 3 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@
from airflow.sdk import __version__
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
DagRunType,
TerminalTIState,
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
TIRunContext,
TITerminalStatePayload,
ValidationError as RemoteValidationError,
VariablePostBody,
Expand Down Expand Up @@ -110,11 +112,12 @@ class TaskInstanceOperations:
def __init__(self, client: Client):
    """Store the shared API ``client`` used for all task-instance requests."""
    self.client = client

def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext:
    """Tell the API server that this TI has started running."""
    # Report where and how the task is running so the server can record it.
    payload = TIEnterRunningPayload(
        pid=pid,
        hostname=get_hostname(),
        unixname=getuser(),
        start_date=when,
    )
    response = self.client.patch(f"task-instances/{id}/run", content=payload.model_dump_json())
    # The server answers with the run context (dag run, variables, connections).
    return TIRunContext.model_validate_json(response.read())

def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
"""Tell the API server that this TI has reached a terminal state."""
Expand Down Expand Up @@ -218,7 +221,23 @@ def auth_flow(self, request: httpx.Request):
# This exists as a aid for debugging or local running via the `dry_run` argument to Client. It doesn't make
# sense for returning connections etc.
def noop_handler(request: httpx.Request) -> httpx.Response:
    """Dry-run transport handler: log each request and fabricate a plausible response."""
    url_path = request.url.path
    log.debug("Dry-run request", method=request.method, path=url_path)

    is_ti_run_call = url_path.startswith("/task-instances/") and url_path.endswith("/run")
    if not is_ti_run_call:
        return httpx.Response(200, json={"text": "Hello, world!"})

    # Fabricate the minimal run context the worker expects from the "run" endpoint.
    fake_context = {
        "dag_run": {
            "dag_id": "test_dag",
            "run_id": "test_run",
            "logical_date": "2021-01-01T00:00:00Z",
            "start_date": "2021-01-01T00:00:00Z",
            "run_type": DagRunType.MANUAL,
        },
    }
    return httpx.Response(200, json=fake_context)


Expand Down
37 changes: 37 additions & 0 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ class ConnectionResponse(BaseModel):
extra: Annotated[str | None, Field(title="Extra")] = None


class DagRunType(str, Enum):
    """String-valued enumeration of the kinds of DagRun."""

    BACKFILL = "backfill"
    SCHEDULED = "scheduled"
    MANUAL = "manual"
    ASSET_TRIGGERED = "asset_triggered"


class IntermediateTIState(str, Enum):
"""
States that a Task Instance can be in that indicate it is not yet in a terminal or running state.
Expand Down Expand Up @@ -159,10 +170,36 @@ class TaskInstance(BaseModel):
map_index: Annotated[int | None, Field(title="Map Index")] = None


class DagRun(BaseModel):
    """
    Schema for DagRun model with minimal required fields needed for Runtime.
    """

    # NOTE(review): this module appears to be auto-generated from the execution API's
    # OpenAPI schema (file name `_generated.py`) — prefer editing the server-side
    # models and regenerating over hand-editing here.
    dag_id: Annotated[str, Field(title="Dag Id")]
    run_id: Annotated[str, Field(title="Run Id")]
    logical_date: Annotated[datetime, Field(title="Logical Date")]
    data_interval_start: Annotated[datetime | None, Field(title="Data Interval Start")] = None
    data_interval_end: Annotated[datetime | None, Field(title="Data Interval End")] = None
    start_date: Annotated[datetime, Field(title="Start Date")]
    end_date: Annotated[datetime | None, Field(title="End Date")] = None
    run_type: DagRunType
    conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None


class HTTPValidationError(BaseModel):
    """Error envelope carrying a list of field-level validation errors (422 responses)."""

    detail: Annotated[list[ValidationError] | None, Field(title="Detail")] = None


class TIRunContext(BaseModel):
    """
    Response schema for TaskInstance run context.

    Client-side mirror of the execution API's TIRunContext: the payload returned
    when a worker reports a TI as running.
    """

    dag_run: DagRun
    # Optional on the wire; absent lists mean "nothing provided", not an error.
    variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None
    connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None


class TITerminalStatePayload(BaseModel):
"""
Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED).
Expand Down
2 changes: 2 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
TaskInstance,
TerminalTIState,
TIDeferredStatePayload,
TIRunContext,
VariableResponse,
XComResponse,
)
Expand All @@ -70,6 +71,7 @@ class StartupDetails(BaseModel):

Responses will come back on stdin
"""
ti_context: TIRunContext
type: Literal["StartupDetails"] = "StartupDetails"


Expand Down
3 changes: 2 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str], requ
# We've forked, but the task won't start doing anything until we send it the StartupDetails
# message. But before we do that, we need to tell the server it's started (so it has the chance to
# tell us "no, stop!" for any reason)
self.client.task_instances.start(ti.id, self.pid, datetime.now(tz=timezone.utc))
ti_context = self.client.task_instances.start(ti.id, self.pid, datetime.now(tz=timezone.utc))
self._last_successful_heartbeat = time.monotonic()
except Exception:
# On any error kill that subprocess!
Expand All @@ -408,6 +408,7 @@ def _on_child_started(self, ti: TaskInstance, path: str | os.PathLike[str], requ
ti=ti,
file=os.fspath(path),
requests_fd=requests_fd,
ti_context=ti_context,
)

# Send the message to tell the process what it needs to execute
Expand Down
Loading