Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-84 Get Batch Task Instances #44051

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def get_task_instances(
)


@mark_fastapi_migration_done
@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE)
@provide_session
def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
Expand Down
82 changes: 59 additions & 23 deletions airflow/api_fastapi/common/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def depends(self, *args: Any, **kwargs: Any) -> Self:
pass


class _LimitFilter(BaseParam[int]):
class LimitFilter(BaseParam[int]):
"""Filter on the limit."""

def to_orm(self, select: Select) -> Select:
Expand All @@ -75,19 +75,19 @@ def to_orm(self, select: Select) -> Select:

return select.limit(self.value)

def depends(self, limit: int = 100) -> _LimitFilter:
def depends(self, limit: int = 100) -> LimitFilter:
return self.set_value(limit)


class _OffsetFilter(BaseParam[int]):
class OffsetFilter(BaseParam[int]):
"""Filter on offset."""

def to_orm(self, select: Select) -> Select:
if self.value is None and self.skip_none:
return select
return select.offset(self.value)

def depends(self, offset: int = 0) -> _OffsetFilter:
def depends(self, offset: int = 0) -> OffsetFilter:
return self.set_value(offset)


Expand Down Expand Up @@ -115,18 +115,54 @@ def depends(self, only_active: bool = True) -> _OnlyActiveFilter:
return self.set_value(only_active)


class _DagIdsFilter(BaseParam[list[str]]):
"""Filter on multi-valued dag_ids param for DagRun."""
class DagIdsFilter(BaseParam[list[str]]):
"""Filter on dag ids."""

def __init__(self, model: Base, value: list[str] | None = None, skip_none: bool = True) -> None:
super().__init__(value, skip_none)
self.model = model

def to_orm(self, select: Select) -> Select:
if self.value and self.skip_none:
return select.where(DagRun.dag_id.in_(self.value))
return select.where(self.model.dag_id.in_(self.value))
return select

def depends(self, dag_ids: list[str] = Query(None)) -> _DagIdsFilter:
def depends(self, dag_ids: list[str] = Query(None)) -> DagIdsFilter:
return self.set_value(dag_ids)


class DagRunIdsFilter(BaseParam[list[str]]):
"""Filter on dag run ids."""

def __init__(self, model: Base, value: list[str] | None = None, skip_none: bool = True) -> None:
super().__init__(value, skip_none)
self.model = model

def to_orm(self, select: Select) -> Select:
if self.value and self.skip_none:
return select.where(self.model.run_id.in_(self.value))
return select

def depends(self, dag_run_ids: list[str] = Query(None)) -> DagRunIdsFilter:
return self.set_value(dag_run_ids)


class TaskIdsFilter(BaseParam[list[str]]):
"""Filter on task ids."""

def __init__(self, model: Base, value: list[str] | None = None, skip_none: bool = True) -> None:
super().__init__(value, skip_none)
self.model = model

def to_orm(self, select: Select) -> Select:
if self.value and self.skip_none:
return select.where(self.model.task_id.in_(self.value))
return select

def depends(self, task_ids: list[str] = Query(None)) -> TaskIdsFilter:
return self.set_value(task_ids)


class _SearchParam(BaseParam[str]):
"""Search on attribute."""

Expand Down Expand Up @@ -273,7 +309,7 @@ def depends(self, owners: list[str] = Query(default_factory=list)) -> _OwnersFil
return self.set_value(owners)


class _TIStateFilter(BaseParam[List[Optional[TaskInstanceState]]]):
class TIStateFilter(BaseParam[List[Optional[TaskInstanceState]]]):
"""Filter on task instance state."""

def to_orm(self, select: Select) -> Select:
Expand All @@ -286,12 +322,12 @@ def to_orm(self, select: Select) -> Select:
conditions = [TaskInstance.state == state for state in self.value]
return select.where(or_(*conditions))

def depends(self, state: list[str] = Query(default_factory=list)) -> _TIStateFilter:
def depends(self, state: list[str] = Query(default_factory=list)) -> TIStateFilter:
states = _convert_ti_states(state)
return self.set_value(states)


class _TIPoolFilter(BaseParam[List[str]]):
class TIPoolFilter(BaseParam[List[str]]):
"""Filter on task instance pool."""

def to_orm(self, select: Select) -> Select:
Expand All @@ -304,11 +340,11 @@ def to_orm(self, select: Select) -> Select:
conditions = [TaskInstance.pool == pool for pool in self.value]
return select.where(or_(*conditions))

def depends(self, pool: list[str] = Query(default_factory=list)) -> _TIPoolFilter:
def depends(self, pool: list[str] = Query(default_factory=list)) -> TIPoolFilter:
return self.set_value(pool)


class _TIQueueFilter(BaseParam[List[str]]):
class TIQueueFilter(BaseParam[List[str]]):
"""Filter on task instance queue."""

def to_orm(self, select: Select) -> Select:
Expand All @@ -321,11 +357,11 @@ def to_orm(self, select: Select) -> Select:
conditions = [TaskInstance.queue == queue for queue in self.value]
return select.where(or_(*conditions))

def depends(self, queue: list[str] = Query(default_factory=list)) -> _TIQueueFilter:
def depends(self, queue: list[str] = Query(default_factory=list)) -> TIQueueFilter:
return self.set_value(queue)


class _TIExecutorFilter(BaseParam[List[str]]):
class TIExecutorFilter(BaseParam[List[str]]):
"""Filter on task instance executor."""

def to_orm(self, select: Select) -> Select:
Expand All @@ -338,7 +374,7 @@ def to_orm(self, select: Select) -> Select:
conditions = [TaskInstance.executor == executor for executor in self.value]
return select.where(or_(*conditions))

def depends(self, executor: list[str] = Query(default_factory=list)) -> _TIExecutorFilter:
def depends(self, executor: list[str] = Query(default_factory=list)) -> TIExecutorFilter:
return self.set_value(executor)


Expand Down Expand Up @@ -581,8 +617,8 @@ def depends_float(
DateTimeQuery = Annotated[str, AfterValidator(_safe_parse_datetime)]

# DAG
QueryLimit = Annotated[_LimitFilter, Depends(_LimitFilter().depends)]
QueryOffset = Annotated[_OffsetFilter, Depends(_OffsetFilter().depends)]
QueryLimit = Annotated[LimitFilter, Depends(LimitFilter().depends)]
QueryOffset = Annotated[OffsetFilter, Depends(OffsetFilter().depends)]
QueryPausedFilter = Annotated[_PausedFilter, Depends(_PausedFilter().depends)]
QueryOnlyActiveFilter = Annotated[_OnlyActiveFilter, Depends(_OnlyActiveFilter().depends)]
QueryDagIdPatternSearch = Annotated[_DagIdPatternSearch, Depends(_DagIdPatternSearch().depends)]
Expand All @@ -597,7 +633,7 @@ def depends_float(

# DagRun
QueryLastDagRunStateFilter = Annotated[_LastDagRunStateFilter, Depends(_LastDagRunStateFilter().depends)]
QueryDagIdsFilter = Annotated[_DagIdsFilter, Depends(_DagIdsFilter().depends)]
QueryDagIdsFilter = Annotated[DagIdsFilter, Depends(DagIdsFilter(DagRun).depends)]

# DAGWarning
QueryDagIdInDagWarningFilter = Annotated[_DagIdFilter, Depends(_DagIdFilter(DagWarning.dag_id).depends)]
Expand All @@ -607,10 +643,10 @@ def depends_float(
QueryDagTagPatternSearch = Annotated[_DagTagNamePatternSearch, Depends(_DagTagNamePatternSearch().depends)]

# TI
QueryTIStateFilter = Annotated[_TIStateFilter, Depends(_TIStateFilter().depends)]
QueryTIPoolFilter = Annotated[_TIPoolFilter, Depends(_TIPoolFilter().depends)]
QueryTIQueueFilter = Annotated[_TIQueueFilter, Depends(_TIQueueFilter().depends)]
QueryTIExecutorFilter = Annotated[_TIExecutorFilter, Depends(_TIExecutorFilter().depends)]
QueryTIStateFilter = Annotated[TIStateFilter, Depends(TIStateFilter().depends)]
QueryTIPoolFilter = Annotated[TIPoolFilter, Depends(TIPoolFilter().depends)]
QueryTIQueueFilter = Annotated[TIQueueFilter, Depends(TIQueueFilter().depends)]
QueryTIExecutorFilter = Annotated[TIExecutorFilter, Depends(TIExecutorFilter().depends)]

# Assets
QueryUriPatternSearch = Annotated[_UriPatternSearch, Depends(_UriPatternSearch().depends)]
Expand Down
33 changes: 32 additions & 1 deletion airflow/api_fastapi/core_api/datamodels/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@
from datetime import datetime
from typing import Annotated

from pydantic import AliasPath, BaseModel, BeforeValidator, ConfigDict, Field
from pydantic import (
AliasPath,
AwareDatetime,
BaseModel,
BeforeValidator,
ConfigDict,
Field,
NonNegativeInt,
)

from airflow.api_fastapi.core_api.datamodels.job import JobResponse
from airflow.api_fastapi.core_api.datamodels.trigger import TriggerResponse
Expand Down Expand Up @@ -83,3 +91,26 @@ class TaskDependencyCollectionResponse(BaseModel):
"""Task scheduling dependencies collection serializer for responses."""

dependencies: list[TaskDependencyResponse]


class TaskInstancesBatchBody(BaseModel):
"""Task Instance body for get batch."""

dag_ids: list[str] | None = None
dag_run_ids: list[str] | None = None
task_ids: list[str] | None = None
state: list[TaskInstanceState | None] | None = None
logical_date_gte: AwareDatetime | None = None
logical_date_lte: AwareDatetime | None = None
start_date_gte: AwareDatetime | None = None
start_date_lte: AwareDatetime | None = None
end_date_gte: AwareDatetime | None = None
end_date_lte: AwareDatetime | None = None
duration_gte: float | None = None
duration_lte: float | None = None
pool: list[str] | None = None
queue: list[str] | None = None
executor: list[str] | None = None
page_offset: NonNegativeInt = 0
page_limit: NonNegativeInt = 100
order_by: str | None = None
Loading