Skip to content

Commit

Permalink
AIP-72: Handling "up_for_reschedule" task instance states (#44907)
Browse files Browse the repository at this point in the history
  • Loading branch information
amoghrajesh authored Dec 18, 2024
1 parent 9c4d711 commit 084218b
Show file tree
Hide file tree
Showing 10 changed files with 180 additions and 5 deletions.
21 changes: 21 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,24 @@ class TIDeferredStatePayload(BaseModel):
trigger_timeout: timedelta | None = None


class TIRescheduleStatePayload(BaseModel):
    """
    Schema for updating TaskInstance to an up_for_reschedule state.

    All three fields are required in code: none of them declares a Python default.
    """

    # The only accepted value is IntermediateTIState.UP_FOR_RESCHEDULE.
    state: Annotated[
        Literal[IntermediateTIState.UP_FOR_RESCHEDULE],
        # Specify a default in the schema, but not in code, so Pydantic marks it as required.
        WithJsonSchema(
            {
                "type": "string",
                "enum": [IntermediateTIState.UP_FOR_RESCHEDULE],
                "default": IntermediateTIState.UP_FOR_RESCHEDULE,
            }
        ),
    ]
    # NOTE(review): presumably the time at which the task should be attempted again;
    # confirm against the route that consumes this payload.
    reschedule_date: UtcDateTime
    # NOTE(review): presumably when the current attempt stopped running; confirm
    # against the route that consumes this payload.
    end_date: UtcDateTime


def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
"""
Determine the discriminator key for TaskInstance state transitions.
Expand All @@ -101,6 +119,8 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
return "_terminal_"
elif state == TIState.DEFERRED:
return "deferred"
elif state == TIState.UP_FOR_RESCHEDULE:
return "up_for_reschedule"
return "_other_"


Expand All @@ -111,6 +131,7 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
Annotated[TITerminalStatePayload, Tag("_terminal_")],
Annotated[TITargetStatePayload, Tag("_other_")],
Annotated[TIDeferredStatePayload, Tag("deferred")],
Annotated[TIRescheduleStatePayload, Tag("up_for_reschedule")],
],
Discriminator(ti_state_discriminator),
]
Expand Down
22 changes: 22 additions & 0 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
TIRescheduleStatePayload,
TIRunContext,
TIStateUpdate,
TITerminalStatePayload,
)
from airflow.models.dagrun import DagRun as DR
from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.utils import timezone
from airflow.utils.state import State
Expand Down Expand Up @@ -228,7 +230,27 @@ def ti_update_state(
next_kwargs=ti_patch_payload.trigger_kwargs,
trigger_timeout=timeout,
)
elif isinstance(ti_patch_payload, TIRescheduleStatePayload):
task_instance = session.get(TI, ti_id_str)
actual_start_date = timezone.utcnow()
session.add(
TaskReschedule(
task_instance.task_id,
task_instance.dag_id,
task_instance.run_id,
task_instance.try_number,
actual_start_date,
ti_patch_payload.end_date,
ti_patch_payload.reschedule_date,
task_instance.map_index,
)
)

query = update(TI).where(TI.id == ti_id_str)
# calculate the duration for TI table too
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
# clear the next_method and next_kwargs so that none of the retries pick them up
query = query.values(state=State.UP_FOR_RESCHEDULE, next_method=None, next_kwargs=None)
# TODO: Replace this with FastAPI's Custom Exception handling:
# https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
try:
Expand Down
9 changes: 9 additions & 0 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
TIRescheduleStatePayload,
TIRunContext,
TITerminalStatePayload,
ValidationError as RemoteValidationError,
Expand All @@ -48,6 +49,7 @@
if TYPE_CHECKING:
from datetime import datetime

from airflow.sdk.execution_time.comms import RescheduleTask
from airflow.typing_compat import ParamSpec

P = ParamSpec("P")
Expand Down Expand Up @@ -137,6 +139,13 @@ def defer(self, id: uuid.UUID, msg):
# Create a deferred state payload from msg
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def reschedule(self, id: uuid.UUID, msg: RescheduleTask):
    """
    Tell the API server that this TI has been rescheduled.

    :param id: UUID of the task instance whose state is being updated.
    :param msg: The reschedule message; only the fields that were explicitly
        set on it are copied into the request payload.
    """
    # Build the reschedule state payload from the explicitly-set fields of msg.
    explicit_fields = msg.model_dump(exclude_unset=True)
    body = TIRescheduleStatePayload(**explicit_fields)

    state_url = f"task-instances/{id}/state"
    self.client.patch(state_url, content=body.model_dump_json())

def set_rtif(self, id: uuid.UUID, body: dict[str, str]) -> dict[str, bool]:
"""Set Rendered Task Instance Fields via the API server."""
self.client.put(f"task-instances/{id}/rtif", json=body)
Expand Down
10 changes: 10 additions & 0 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ class TIHeartbeatInfo(BaseModel):
pid: Annotated[int, Field(title="Pid")]


class TIRescheduleStatePayload(BaseModel):
    """
    Schema for updating TaskInstance to an up_for_reschedule state.
    """

    # Defaults to "up_for_reschedule", so callers do not have to pass it explicitly.
    state: Annotated[Literal["up_for_reschedule"] | None, Field(title="State")] = "up_for_reschedule"
    # Both timestamps are required: neither declares a default.
    end_date: Annotated[datetime, Field(title="End Date")]
    reschedule_date: Annotated[datetime, Field(title="Reschedule Date")]


class TITargetStatePayload(BaseModel):
"""
Schema for updating TaskInstance to a target state, excluding terminal and running states.
Expand Down
19 changes: 18 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
TaskInstance,
TerminalTIState,
TIDeferredStatePayload,
TIRescheduleStatePayload,
TIRunContext,
VariableResponse,
XComResponse,
Expand Down Expand Up @@ -115,6 +116,12 @@ class DeferTask(TIDeferredStatePayload):
type: Literal["DeferTask"] = "DeferTask"


class RescheduleTask(TIRescheduleStatePayload):
    """
    Update a task instance state to reschedule/up_for_reschedule.

    Inherits the payload fields of ``TIRescheduleStatePayload`` and adds a
    constant ``type`` tag.
    """

    # Constant tag; the ``ToSupervisor`` union uses ``type`` as its discriminator field.
    type: Literal["RescheduleTask"] = "RescheduleTask"


class GetXCom(BaseModel):
key: str
dag_id: str
Expand Down Expand Up @@ -183,6 +190,16 @@ class SetRenderedFields(BaseModel):


ToSupervisor = Annotated[
Union[TaskState, GetXCom, GetConnection, GetVariable, DeferTask, PutVariable, SetXCom, SetRenderedFields],
Union[
TaskState,
GetXCom,
GetConnection,
GetVariable,
DeferTask,
PutVariable,
SetXCom,
SetRenderedFields,
RescheduleTask,
],
Field(discriminator="type"),
]
4 changes: 4 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
GetVariable,
GetXCom,
PutVariable,
RescheduleTask,
SetXCom,
StartupDetails,
TaskState,
Expand Down Expand Up @@ -698,6 +699,9 @@ def _handle_request(self, msg, log):
elif isinstance(msg, DeferTask):
self._terminal_state = IntermediateTIState.DEFERRED
self.client.task_instances.defer(self.id, msg)
elif isinstance(msg, RescheduleTask):
self._terminal_state = IntermediateTIState.UP_FOR_RESCHEDULE
self.client.task_instances.reschedule(self.id, msg)
elif isinstance(msg, SetXCom):
self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index)
elif isinstance(msg, PutVariable):
Expand Down
7 changes: 5 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import (
DeferTask,
RescheduleTask,
SetRenderedFields,
StartupDetails,
TaskState,
Expand Down Expand Up @@ -279,8 +280,10 @@ def run(ti: RuntimeTaskInstance, log: Logger):
state=TerminalTIState.SKIPPED,
end_date=datetime.now(tz=timezone.utc),
)
except AirflowRescheduleException:
...
except AirflowRescheduleException as reschedule:
msg = RescheduleTask(
reschedule_date=reschedule.reschedule_date, end_date=datetime.now(tz=timezone.utc)
)
except (AirflowFailException, AirflowSensorTimeout):
# If AirflowFailException is raised, task should not retry.
# If a sensor in reschedule mode reaches timeout, task should not retry.
Expand Down
25 changes: 24 additions & 1 deletion task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@

from airflow.sdk.api.client import Client, RemoteValidationError, ServerResponseError
from airflow.sdk.api.datamodels._generated import VariableResponse, XComResponse
from airflow.sdk.execution_time.comms import DeferTask
from airflow.sdk.execution_time.comms import DeferTask, RescheduleTask
from airflow.utils import timezone
from airflow.utils.state import TerminalTIState


Expand Down Expand Up @@ -183,6 +184,28 @@ def handle_request(request: httpx.Request) -> httpx.Response:
)
client.task_instances.defer(ti_id, msg)

def test_task_instance_reschedule(self):
    """Check that ``reschedule`` sends the expected reschedule payload to the TI state endpoint."""
    ti_id = uuid6.uuid7()

    def handle_request(request: httpx.Request) -> httpx.Response:
        # Anything other than the state endpoint of this TI is rejected.
        if request.url.path != f"/task-instances/{ti_id}/state":
            return httpx.Response(status_code=400, json={"detail": "Bad Request"})

        # Simulate a successful response from the server that reschedules a task.
        actual_body = json.loads(request.read())
        assert actual_body["state"] == "up_for_reschedule"
        assert actual_body["reschedule_date"] == "2024-10-31T12:00:00Z"
        assert actual_body["end_date"] == "2024-10-31T12:00:00Z"
        return httpx.Response(status_code=204)

    mock_transport = httpx.MockTransport(handle_request)
    client = make_client(transport=mock_transport)

    msg = RescheduleTask(
        reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
        end_date=timezone.parse("2024-10-31T12:00:00Z"),
    )
    client.task_instances.reschedule(ti_id, msg)

@pytest.mark.parametrize(
"rendered_fields",
[
Expand Down
18 changes: 18 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
GetVariable,
GetXCom,
PutVariable,
RescheduleTask,
SetXCom,
TaskState,
VariableResult,
Expand Down Expand Up @@ -793,6 +794,23 @@ def watched_subprocess(self, mocker):
"",
id="patch_task_instance_to_deferred",
),
pytest.param(
RescheduleTask(
reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
end_date=timezone.parse("2024-10-31T12:00:00Z"),
),
b"",
"task_instances.reschedule",
(
TI_ID,
RescheduleTask(
reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
end_date=timezone.parse("2024-10-31T12:00:00Z"),
),
),
"",
id="patch_task_instance_to_up_for_reschedule",
),
pytest.param(
GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"),
b'{"key":"test_key","value":"test_value"}\n',
Expand Down
50 changes: 49 additions & 1 deletion tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError

from airflow.models import RenderedTaskInstanceFields, Trigger
from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.state import State, TaskInstanceState
Expand Down Expand Up @@ -286,6 +286,54 @@ def test_ti_update_state_to_deferred(self, client, session, create_task_instance
assert t[0].classpath == "my-classpath"
assert t[0].kwargs == {"key": "value"}

def test_ti_update_state_to_reschedule(self, client, session, create_task_instance, time_machine):
    """
    Verify that patching a running TI to ``up_for_reschedule`` is handled correctly.

    The TI row must end up in the reschedule state with its ``next_*`` fields
    cleared, and exactly one TaskReschedule row must be recorded for it.
    """
    # Freeze the clock so the timestamps asserted below are deterministic.
    frozen_now = timezone.datetime(2024, 10, 30)
    time_machine.move_to(frozen_now, tick=False)

    ti = create_task_instance(
        task_id="test_ti_update_state_to_reschedule",
        state=State.RUNNING,
        session=session,
    )
    ti.start_date = frozen_now
    session.commit()

    request_body = {
        "state": "up_for_reschedule",
        "reschedule_date": "2024-10-31T11:03:00+00:00",
        "end_date": DEFAULT_END_DATE.isoformat(),
    }
    state_url = f"/execution/task-instances/{ti.id}/state"
    response = client.patch(state_url, json=request_body)

    assert response.status_code == 204
    assert response.text == ""

    # Drop cached ORM state so the assertions below read what was persisted.
    session.expire_all()

    task_instances = session.query(TaskInstance).all()
    assert len(task_instances) == 1
    updated_ti = task_instances[0]
    assert updated_ti.state == TaskInstanceState.UP_FOR_RESCHEDULE
    assert updated_ti.next_method is None
    assert updated_ti.next_kwargs is None
    # 129600 s == 36 h; presumably the gap between the frozen start and DEFAULT_END_DATE - confirm.
    assert updated_ti.duration == 129600

    reschedules = session.query(TaskReschedule).all()
    assert len(reschedules) == 1
    recorded = reschedules[0]
    assert recorded.dag_id == "dag"
    assert recorded.task_id == "test_ti_update_state_to_reschedule"
    assert recorded.run_id == "test"
    assert recorded.try_number == 0
    assert recorded.start_date == frozen_now
    assert recorded.end_date == DEFAULT_END_DATE
    assert recorded.reschedule_date == timezone.parse("2024-10-31T11:03:00+00:00")
    assert recorded.map_index == -1
    assert recorded.duration == 129600


class TestTIHealthEndpoint:
def setup_method(self):
Expand Down

0 comments on commit 084218b

Please sign in to comment.