diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index 68267317af3b6c..83456960753d65 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -367,6 +367,7 @@ def get_task_instances( ) +@mark_fastapi_migration_done @security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: diff --git a/airflow/api_fastapi/common/parameters.py b/airflow/api_fastapi/common/parameters.py index abf6378ac5b293..337d85547c3d04 100644 --- a/airflow/api_fastapi/common/parameters.py +++ b/airflow/api_fastapi/common/parameters.py @@ -66,7 +66,7 @@ def depends(self, *args: Any, **kwargs: Any) -> Self: pass -class _LimitFilter(BaseParam[int]): +class LimitFilter(BaseParam[int]): """Filter on the limit.""" def to_orm(self, select: Select) -> Select: @@ -75,11 +75,11 @@ def to_orm(self, select: Select) -> Select: return select.limit(self.value) - def depends(self, limit: int = 100) -> _LimitFilter: + def depends(self, limit: int = 100) -> LimitFilter: return self.set_value(limit) -class _OffsetFilter(BaseParam[int]): +class OffsetFilter(BaseParam[int]): """Filter on offset.""" def to_orm(self, select: Select) -> Select: @@ -87,7 +87,7 @@ def to_orm(self, select: Select) -> Select: return select return select.offset(self.value) - def depends(self, offset: int = 0) -> _OffsetFilter: + def depends(self, offset: int = 0) -> OffsetFilter: return self.set_value(offset) @@ -115,18 +115,54 @@ def depends(self, only_active: bool = True) -> _OnlyActiveFilter: return self.set_value(only_active) -class _DagIdsFilter(BaseParam[list[str]]): - """Filter on multi-valued dag_ids param for DagRun.""" +class DagIdsFilter(BaseParam[list[str]]): + """Filter on dag ids.""" + + def __init__(self, model: Base, value: list[str] | None = None, skip_none: bool = True) -> None: + super().__init__(value, skip_none) + self.model = model def to_orm(self, select: Select) -> Select: if self.value and self.skip_none: - return select.where(DagRun.dag_id.in_(self.value)) + return select.where(self.model.dag_id.in_(self.value)) return select - def depends(self, dag_ids: list[str] = Query(None)) -> _DagIdsFilter: + def depends(self, dag_ids: list[str] = Query(None)) -> DagIdsFilter: return self.set_value(dag_ids) +class DagRunIdsFilter(BaseParam[list[str]]): + """Filter on dag run ids.""" + + def __init__(self, model: Base, value: list[str] | None = None, skip_none: bool = True) -> None: + super().__init__(value, skip_none) + self.model = model + + def to_orm(self, select: Select) -> Select: + if self.value and self.skip_none: + return select.where(self.model.run_id.in_(self.value)) + return select + + def depends(self, dag_run_ids: list[str] = Query(None)) -> DagRunIdsFilter: + return self.set_value(dag_run_ids) + + +class TaskIdsFilter(BaseParam[list[str]]): + """Filter on task ids.""" + + def __init__(self, model: Base, value: list[str] | None = None, skip_none: bool = True) -> None: + super().__init__(value, skip_none) + self.model = model + + def to_orm(self, select: Select) -> Select: + if self.value and self.skip_none: + return select.where(self.model.task_id.in_(self.value)) + return select + + def depends(self, task_ids: list[str] = Query(None)) -> TaskIdsFilter: + return self.set_value(task_ids) + + class _SearchParam(BaseParam[str]): """Search on attribute.""" @@ -273,7 +309,7 @@ def 
depends(self, owners: list[str] = Query(default_factory=list)) -> _OwnersFil return self.set_value(owners) -class _TIStateFilter(BaseParam[List[Optional[TaskInstanceState]]]): +class TIStateFilter(BaseParam[List[Optional[TaskInstanceState]]]): """Filter on task instance state.""" def to_orm(self, select: Select) -> Select: @@ -286,12 +322,12 @@ def to_orm(self, select: Select) -> Select: conditions = [TaskInstance.state == state for state in self.value] return select.where(or_(*conditions)) - def depends(self, state: list[str] = Query(default_factory=list)) -> _TIStateFilter: + def depends(self, state: list[str] = Query(default_factory=list)) -> TIStateFilter: states = _convert_ti_states(state) return self.set_value(states) -class _TIPoolFilter(BaseParam[List[str]]): +class TIPoolFilter(BaseParam[List[str]]): """Filter on task instance pool.""" def to_orm(self, select: Select) -> Select: @@ -304,11 +340,11 @@ def to_orm(self, select: Select) -> Select: conditions = [TaskInstance.pool == pool for pool in self.value] return select.where(or_(*conditions)) - def depends(self, pool: list[str] = Query(default_factory=list)) -> _TIPoolFilter: + def depends(self, pool: list[str] = Query(default_factory=list)) -> TIPoolFilter: return self.set_value(pool) -class _TIQueueFilter(BaseParam[List[str]]): +class TIQueueFilter(BaseParam[List[str]]): """Filter on task instance queue.""" def to_orm(self, select: Select) -> Select: @@ -321,11 +357,11 @@ def to_orm(self, select: Select) -> Select: conditions = [TaskInstance.queue == queue for queue in self.value] return select.where(or_(*conditions)) - def depends(self, queue: list[str] = Query(default_factory=list)) -> _TIQueueFilter: + def depends(self, queue: list[str] = Query(default_factory=list)) -> TIQueueFilter: return self.set_value(queue) -class _TIExecutorFilter(BaseParam[List[str]]): +class TIExecutorFilter(BaseParam[List[str]]): """Filter on task instance executor.""" def to_orm(self, select: Select) -> Select: @@ -338,7 +374,7 @@ def to_orm(self, select: Select) -> Select: conditions = [TaskInstance.executor == executor for executor in self.value] return select.where(or_(*conditions)) - def depends(self, executor: list[str] = Query(default_factory=list)) -> _TIExecutorFilter: + def depends(self, executor: list[str] = Query(default_factory=list)) -> TIExecutorFilter: return self.set_value(executor) @@ -581,8 +617,8 @@ def depends_float( DateTimeQuery = Annotated[str, AfterValidator(_safe_parse_datetime)] # DAG -QueryLimit = Annotated[_LimitFilter, Depends(_LimitFilter().depends)] -QueryOffset = Annotated[_OffsetFilter, Depends(_OffsetFilter().depends)] +QueryLimit = Annotated[LimitFilter, Depends(LimitFilter().depends)] +QueryOffset = Annotated[OffsetFilter, Depends(OffsetFilter().depends)] QueryPausedFilter = Annotated[_PausedFilter, Depends(_PausedFilter().depends)] QueryOnlyActiveFilter = Annotated[_OnlyActiveFilter, Depends(_OnlyActiveFilter().depends)] QueryDagIdPatternSearch = Annotated[_DagIdPatternSearch, Depends(_DagIdPatternSearch().depends)] @@ -597,7 +633,7 @@ def depends_float( # DagRun QueryLastDagRunStateFilter = Annotated[_LastDagRunStateFilter, Depends(_LastDagRunStateFilter().depends)] -QueryDagIdsFilter = Annotated[_DagIdsFilter, Depends(_DagIdsFilter().depends)] +QueryDagIdsFilter = Annotated[DagIdsFilter, Depends(DagIdsFilter(DagRun).depends)] # DAGWarning QueryDagIdInDagWarningFilter = Annotated[_DagIdFilter, Depends(_DagIdFilter(DagWarning.dag_id).depends)] @@ -607,10 +643,10 @@ def depends_float( QueryDagTagPatternSearch = 
Annotated[_DagTagNamePatternSearch, Depends(_DagTagNamePatternSearch().depends)] # TI -QueryTIStateFilter = Annotated[_TIStateFilter, Depends(_TIStateFilter().depends)] -QueryTIPoolFilter = Annotated[_TIPoolFilter, Depends(_TIPoolFilter().depends)] -QueryTIQueueFilter = Annotated[_TIQueueFilter, Depends(_TIQueueFilter().depends)] -QueryTIExecutorFilter = Annotated[_TIExecutorFilter, Depends(_TIExecutorFilter().depends)] +QueryTIStateFilter = Annotated[TIStateFilter, Depends(TIStateFilter().depends)] +QueryTIPoolFilter = Annotated[TIPoolFilter, Depends(TIPoolFilter().depends)] +QueryTIQueueFilter = Annotated[TIQueueFilter, Depends(TIQueueFilter().depends)] +QueryTIExecutorFilter = Annotated[TIExecutorFilter, Depends(TIExecutorFilter().depends)] # Assets QueryUriPatternSearch = Annotated[_UriPatternSearch, Depends(_UriPatternSearch().depends)] diff --git a/airflow/api_fastapi/core_api/datamodels/task_instances.py b/airflow/api_fastapi/core_api/datamodels/task_instances.py index a54b85e58ddfb6..cd4caf1b6119a9 100644 --- a/airflow/api_fastapi/core_api/datamodels/task_instances.py +++ b/airflow/api_fastapi/core_api/datamodels/task_instances.py @@ -19,7 +19,15 @@ from datetime import datetime from typing import Annotated -from pydantic import AliasPath, BaseModel, BeforeValidator, ConfigDict, Field +from pydantic import ( + AliasPath, + AwareDatetime, + BaseModel, + BeforeValidator, + ConfigDict, + Field, + NonNegativeInt, +) from airflow.api_fastapi.core_api.datamodels.job import JobResponse from airflow.api_fastapi.core_api.datamodels.trigger import TriggerResponse @@ -83,3 +91,26 @@ class TaskDependencyCollectionResponse(BaseModel): """Task scheduling dependencies collection serializer for responses.""" dependencies: list[TaskDependencyResponse] + + +class TaskInstancesBatchBody(BaseModel): + """Task Instance body for get batch.""" + + dag_ids: list[str] | None = None + dag_run_ids: list[str] | None = None + task_ids: list[str] | None = None + state: list[TaskInstanceState | None] | None = None + logical_date_gte: AwareDatetime | None = None + logical_date_lte: AwareDatetime | None = None + start_date_gte: AwareDatetime | None = None + start_date_lte: AwareDatetime | None = None + end_date_gte: AwareDatetime | None = None + end_date_lte: AwareDatetime | None = None + duration_gte: float | None = None + duration_lte: float | None = None + pool: list[str] | None = None + queue: list[str] | None = None + executor: list[str] | None = None + page_offset: NonNegativeInt = 0 + page_limit: NonNegativeInt = 100 + order_by: str | None = None diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 679eedb0479174..e7762392a0c8b1 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -3299,6 +3299,50 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/list: + post: + tags: + - Task Instance + summary: Get Task Instances Batch + description: Get list of task instances. 
+ operationId: get_task_instances_batch + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/TaskInstancesBatchBody' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/TaskInstanceCollectionResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + '403': + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + '404': + description: Not Found + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /public/dags/{dag_id}/tasks/: get: tags: @@ -6006,6 +6050,123 @@ components: - deferred title: TaskInstanceStateCount description: TaskInstance serializer for responses. + TaskInstancesBatchBody: + properties: + dag_ids: + anyOf: + - items: + type: string + type: array + - type: 'null' + title: Dag Ids + dag_run_ids: + anyOf: + - items: + type: string + type: array + - type: 'null' + title: Dag Run Ids + task_ids: + anyOf: + - items: + type: string + type: array + - type: 'null' + title: Task Ids + state: + anyOf: + - items: + anyOf: + - $ref: '#/components/schemas/TaskInstanceState' + - type: 'null' + type: array + - type: 'null' + title: State + logical_date_gte: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Logical Date Gte + logical_date_lte: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Logical Date Lte + start_date_gte: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Start Date Gte + start_date_lte: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Start Date Lte + end_date_gte: + anyOf: + - type: string + format: date-time + - type: 'null' + title: End Date Gte + end_date_lte: + anyOf: + - type: string + format: date-time + - type: 'null' + title: End Date Lte + duration_gte: + anyOf: + - type: number + - type: 'null' + title: Duration Gte + duration_lte: + anyOf: + - type: number + - type: 'null' + title: Duration Lte + pool: + anyOf: + - items: + type: string + type: array + - type: 'null' + title: Pool + queue: + anyOf: + - items: + type: string + type: array + - type: 'null' + title: Queue + executor: + anyOf: + - items: + type: string + type: array + - type: 'null' + title: Executor + page_offset: + type: integer + minimum: 0.0 + title: Page Offset + default: 0 + page_limit: + type: integer + minimum: 0.0 + title: Page Limit + default: 100 + order_by: + anyOf: + - type: string + - type: 'null' + title: Order By + type: object + title: TaskInstancesBatchBody + description: Task Instance body for get batch. 
TaskOutletAssetReference: properties: dag_id: diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index 516e113fc24f1d..271f75e69e684b 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -25,14 +25,24 @@ from airflow.api_fastapi.common.db.common import get_session, paginated_select from airflow.api_fastapi.common.parameters import ( + DagIdsFilter, + DagRunIdsFilter, + LimitFilter, + OffsetFilter, QueryLimit, QueryOffset, QueryTIExecutorFilter, QueryTIPoolFilter, QueryTIQueueFilter, QueryTIStateFilter, + Range, RangeFilter, SortParam, + TaskIdsFilter, + TIExecutorFilter, + TIPoolFilter, + TIQueueFilter, + TIStateFilter, datetime_range_filter_factory, float_range_filter_factory, ) @@ -41,6 +51,7 @@ TaskDependencyCollectionResponse, TaskInstanceCollectionResponse, TaskInstanceResponse, + TaskInstancesBatchBody, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.exceptions import TaskNotFound @@ -158,7 +169,7 @@ def get_mapped_task_instances( session, ) - task_instances = session.scalars(task_instance_select).all() + task_instances = session.scalars(task_instance_select) return TaskInstanceCollectionResponse( task_instances=[ @@ -322,7 +333,85 @@ def get_task_instances( session, ) - task_instances = session.scalars(task_instance_select).all() + task_instances = session.scalars(task_instance_select) + + return TaskInstanceCollectionResponse( + task_instances=[ + TaskInstanceResponse.model_validate(task_instance, from_attributes=True) + for task_instance in task_instances + ], + total_entries=total_entries, + ) + + +@task_instances_router.post( + "/list", + responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), +) +def get_task_instances_batch( + body: TaskInstancesBatchBody, + session: Annotated[Session, Depends(get_session)], +) -> TaskInstanceCollectionResponse: + """Get list of task instances.""" + dag_ids = DagIdsFilter(TI, body.dag_ids) + dag_run_ids = DagRunIdsFilter(TI, body.dag_run_ids) + task_ids = TaskIdsFilter(TI, body.task_ids) + logical_date = RangeFilter( + Range(lower_bound=body.logical_date_gte, upper_bound=body.logical_date_lte), + attribute=TI.logical_date, + ) + start_date = RangeFilter( + Range(lower_bound=body.start_date_gte, upper_bound=body.start_date_lte), + attribute=TI.start_date, + ) + end_date = RangeFilter( + Range(lower_bound=body.end_date_gte, upper_bound=body.end_date_lte), + attribute=TI.end_date, + ) + duration = RangeFilter( + Range(lower_bound=body.duration_gte, upper_bound=body.duration_lte), + attribute=TI.duration, + ) + state = TIStateFilter(body.state) + pool = TIPoolFilter(body.pool) + queue = TIQueueFilter(body.queue) + executor = TIExecutorFilter(body.executor) + + offset = OffsetFilter(body.page_offset) + limit = LimitFilter(body.page_limit) + + order_by = SortParam( + ["id", "state", "duration", "start_date", "end_date", "map_index"], + TI, + ).set_value(body.order_by) + + base_query = select(TI).join(TI.dag_run) + task_instance_select, total_entries = paginated_select( + base_query, + [ + dag_ids, + dag_run_ids, + task_ids, + logical_date, + start_date, + end_date, + duration, + state, + pool, + queue, + executor, + ], + order_by, + offset, + limit, + session, + ) + + task_instance_select = task_instance_select.options( + joinedload(TI.rendered_task_instance_fields), 
joinedload(TI.task_instance_note) + ) + + task_instances = session.scalars(task_instance_select) return TaskInstanceCollectionResponse( task_instances=[ diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 56d92e828772a7..46940bfd318938 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -1103,6 +1103,9 @@ export type DagRunServiceClearDagRunMutationResult = Awaited< export type PoolServicePostPoolMutationResult = Awaited< ReturnType >; +export type TaskInstanceServiceGetTaskInstancesBatchMutationResult = Awaited< + ReturnType +>; export type VariableServicePostVariableMutationResult = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index b2556a2cd5986e..a96b09e12a7990 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -40,6 +40,7 @@ import { DagWarningType, PoolPatchBody, PoolPostBody, + TaskInstancesBatchBody, VariableBody, } from "../requests/types.gen"; import * as Common from "./common"; @@ -2000,6 +2001,45 @@ export const usePoolServicePostPool = < PoolService.postPool({ requestBody }) as unknown as Promise, ...options, }); +/** + * Get Task Instances Batch + * Get list of task instances. + * @param data The data for the request. + * @param data.requestBody + * @returns TaskInstanceCollectionResponse Successful Response + * @throws ApiError + */ +export const useTaskInstanceServiceGetTaskInstancesBatch = < + TData = Common.TaskInstanceServiceGetTaskInstancesBatchMutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + requestBody: TaskInstancesBatchBody; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + requestBody: TaskInstancesBatchBody; + }, + TContext + >({ + mutationFn: ({ requestBody }) => + TaskInstanceService.getTaskInstancesBatch({ + requestBody, + }) as unknown as Promise, + ...options, + }); /** * Post Variable * Create a variable. 
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index 1a4ab7f498cae0..e5ac0441a2aaa8 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -3359,6 +3359,236 @@ export const $TaskInstanceStateCount = { description: "TaskInstance serializer for responses.", } as const; +export const $TaskInstancesBatchBody = { + properties: { + dag_ids: { + anyOf: [ + { + items: { + type: "string", + }, + type: "array", + }, + { + type: "null", + }, + ], + title: "Dag Ids", + }, + dag_run_ids: { + anyOf: [ + { + items: { + type: "string", + }, + type: "array", + }, + { + type: "null", + }, + ], + title: "Dag Run Ids", + }, + task_ids: { + anyOf: [ + { + items: { + type: "string", + }, + type: "array", + }, + { + type: "null", + }, + ], + title: "Task Ids", + }, + state: { + anyOf: [ + { + items: { + anyOf: [ + { + $ref: "#/components/schemas/TaskInstanceState", + }, + { + type: "null", + }, + ], + }, + type: "array", + }, + { + type: "null", + }, + ], + title: "State", + }, + logical_date_gte: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "Logical Date Gte", + }, + logical_date_lte: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "Logical Date Lte", + }, + start_date_gte: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "Start Date Gte", + }, + start_date_lte: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "Start Date Lte", + }, + end_date_gte: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "End Date Gte", + }, + end_date_lte: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "End Date Lte", + }, + duration_gte: { + anyOf: [ + { + type: "number", + }, + { + type: "null", + }, + ], + title: "Duration Gte", + }, + duration_lte: { + anyOf: [ + { + type: "number", + }, + { + type: "null", + }, + ], + title: "Duration Lte", + }, + pool: { + anyOf: [ + { + items: { + type: "string", + }, + type: "array", + }, + { + type: "null", + }, + ], + title: "Pool", + }, + queue: { + anyOf: [ + { + items: { + type: "string", + }, + type: "array", + }, + { + type: "null", + }, + ], + title: "Queue", + }, + executor: { + anyOf: [ + { + items: { + type: "string", + }, + type: "array", + }, + { + type: "null", + }, + ], + title: "Executor", + }, + page_offset: { + type: "integer", + minimum: 0, + title: "Page Offset", + default: 0, + }, + page_limit: { + type: "integer", + minimum: 0, + title: "Page Limit", + default: 100, + }, + order_by: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Order By", + }, + }, + type: "object", + title: "TaskInstancesBatchBody", + description: "Task Instance body for get batch.", +} as const; + export const $TaskOutletAssetReference = { properties: { dag_id: { diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 7a5d1389eb53c4..53bb3527d1421c 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -105,6 +105,8 @@ import type { GetMappedTaskInstanceResponse, GetTaskInstancesData, GetTaskInstancesResponse, + GetTaskInstancesBatchData, + GetTaskInstancesBatchResponse, GetTasksData, GetTasksResponse, 
GetTaskData, @@ -1755,6 +1757,31 @@ export class TaskInstanceService { }, }); } + + /** + * Get Task Instances Batch + * Get list of task instances. + * @param data The data for the request. + * @param data.requestBody + * @returns TaskInstanceCollectionResponse Successful Response + * @throws ApiError + */ + public static getTaskInstancesBatch( + data: GetTaskInstancesBatchData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "POST", + url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/list", + body: data.requestBody, + mediaType: "application/json", + errors: { + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 422: "Validation Error", + }, + }); + } } export class TaskService { diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 6d8af631417f11..078699cc0f2b92 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -833,6 +833,30 @@ export type TaskInstanceStateCount = { deferred: number; }; +/** + * Task Instance body for get batch. + */ +export type TaskInstancesBatchBody = { + dag_ids?: Array | null; + dag_run_ids?: Array | null; + task_ids?: Array | null; + state?: Array | null; + logical_date_gte?: string | null; + logical_date_lte?: string | null; + start_date_gte?: string | null; + start_date_lte?: string | null; + end_date_gte?: string | null; + end_date_lte?: string | null; + duration_gte?: number | null; + duration_lte?: number | null; + pool?: Array | null; + queue?: Array | null; + executor?: Array | null; + page_offset?: number; + page_limit?: number; + order_by?: string | null; +}; + /** * Task outlet reference serializer for assets. */ @@ -1423,6 +1447,12 @@ export type GetTaskInstancesData = { export type GetTaskInstancesResponse = TaskInstanceCollectionResponse; +export type GetTaskInstancesBatchData = { + requestBody: TaskInstancesBatchBody; +}; + +export type GetTaskInstancesBatchResponse = TaskInstanceCollectionResponse; + export type GetTasksData = { dagId: string; orderBy?: string; @@ -2837,6 +2867,33 @@ export type $OpenApiTs = { }; }; }; + "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/list": { + post: { + req: GetTaskInstancesBatchData; + res: { + /** + * Successful Response + */ + 200: TaskInstanceCollectionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; "/public/dags/{dag_id}/tasks/": { get: { req: GetTasksData; diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index 12f6e02a1fa1b4..c71c017a63ec0e 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -1127,3 +1127,283 @@ def test_should_respond_dependencies_mapped(self, test_client, session): "print_the_context/0/dependencies", ) assert response.status_code == 200, response.text + + +class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): + @pytest.mark.parametrize( + "task_instances, update_extras, payload, expected_ti_count", + [ + pytest.param( + [ + {"queue": "test_queue_1"}, + {"queue": "test_queue_2"}, + {"queue": "test_queue_3"}, + ], + True, + {"queue": ["test_queue_1", "test_queue_2"]}, + 2, + id="test queue filter", + ), + pytest.param( + [ + 
{"executor": "test_exec_1"}, + {"executor": "test_exec_2"}, + {"executor": "test_exec_3"}, + ], + True, + {"executor": ["test_exec_1", "test_exec_2"]}, + 2, + id="test executor filter", + ), + pytest.param( + [ + {"duration": 100}, + {"duration": 150}, + {"duration": 200}, + ], + True, + {"duration_gte": 100, "duration_lte": 200}, + 3, + id="test duration filter", + ), + pytest.param( + [ + {"logical_date": DEFAULT_DATETIME_1}, + {"logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, + {"logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, + {"logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=3)}, + {"logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=4)}, + {"logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=5)}, + ], + False, + { + "logical_date_gte": DEFAULT_DATETIME_1.isoformat(), + "logical_date_lte": (DEFAULT_DATETIME_1 + dt.timedelta(days=2)).isoformat(), + }, + 3, + id="with execution date filter", + ), + pytest.param( + [ + {"logical_date": DEFAULT_DATETIME_1}, + {"logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, + {"logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, + {"logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=3)}, + ], + False, + { + "dag_run_ids": ["TEST_DAG_RUN_ID_0", "TEST_DAG_RUN_ID_1"], + }, + 2, + id="test dag run id filter", + ), + pytest.param( + [ + {"logical_date": DEFAULT_DATETIME_1}, + {"logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, + {"logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, + {"logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=3)}, + ], + False, + { + "task_ids": ["print_the_context", "log_sql_query"], + }, + 2, + id="test task id filter", + ), + ], + ) + def test_should_respond_200( + self, test_client, task_instances, update_extras, payload, expected_ti_count, session + ): + self.create_task_instances( + session, + update_extras=update_extras, + task_instances=task_instances, + ) + response = test_client.post( + "/public/dags/~/dagRuns/~/taskInstances/list", + json=payload, + ) + body = response.json() + assert response.status_code == 200, body + assert expected_ti_count == body["total_entries"] + assert expected_ti_count == len(body["task_instances"]) + + def test_should_respond_200_for_order_by(self, test_client, session): + dag_id = "example_python_operator" + self.create_task_instances( + session, + task_instances=[ + {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(minutes=(i + 1))} for i in range(10) + ], + dag_id=dag_id, + ) + + ti_count = session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).count() + + # Ascending order + response_asc = test_client.post( + "/public/dags/~/dagRuns/~/taskInstances/list", + json={"order_by": "start_date", "dag_ids": [dag_id]}, + ) + assert response_asc.status_code == 200, response_asc.json() + assert response_asc.json()["total_entries"] == ti_count + assert len(response_asc.json()["task_instances"]) == ti_count + + # Descending order + response_desc = test_client.post( + "/public/dags/~/dagRuns/~/taskInstances/list", + json={"order_by": "-start_date", "dag_ids": [dag_id]}, + ) + assert response_desc.status_code == 200, response_desc.json() + assert response_desc.json()["total_entries"] == ti_count + assert len(response_desc.json()["task_instances"]) == ti_count + + # Compare + start_dates_asc = [ti["start_date"] for ti in response_asc.json()["task_instances"]] + assert len(start_dates_asc) == ti_count + start_dates_desc = [ti["start_date"] for ti in response_desc.json()["task_instances"]] + assert len(start_dates_desc) == 
ti_count + assert start_dates_asc == list(reversed(start_dates_desc)) + + @pytest.mark.parametrize( + "task_instances, payload, expected_ti_count", + [ + pytest.param( + [ + {"task": "test_1"}, + {"task": "test_2"}, + ], + {"dag_ids": ["latest_only"]}, + 2, + id="task_instance properties", + ), + ], + ) + def test_should_respond_200_when_task_instance_properties_are_none( + self, test_client, task_instances, payload, expected_ti_count, session + ): + self.ti_extras.update( + { + "start_date": None, + "end_date": None, + "state": None, + } + ) + self.create_task_instances( + session, + dag_id="latest_only", + task_instances=task_instances, + ) + response = test_client.post( + "/public/dags/~/dagRuns/~/taskInstances/list", + json=payload, + ) + body = response.json() + assert response.status_code == 200, body + assert expected_ti_count == body["total_entries"] + assert expected_ti_count == len(body["task_instances"]) + + @pytest.mark.parametrize( + "payload, expected_ti, total_ti", + [ + pytest.param( + {"dag_ids": ["example_python_operator", "example_skip_dag"]}, + 17, + 17, + id="with dag filter", + ), + ], + ) + def test_should_respond_200_dag_ids_filter(self, test_client, payload, expected_ti, total_ti, session): + self.create_task_instances(session) + self.create_task_instances(session, dag_id="example_skip_dag") + response = test_client.post( + "/public/dags/~/dagRuns/~/taskInstances/list", + json=payload, + ) + assert response.status_code == 200 + assert len(response.json()["task_instances"]) == expected_ti + assert response.json()["total_entries"] == total_ti + + def test_should_raise_400_for_no_json(self, test_client): + response = test_client.post( + "/public/dags/~/dagRuns/~/taskInstances/list", + ) + assert response.status_code == 422 + assert response.json()["detail"] == [ + { + "input": None, + "loc": ["body"], + "msg": "Field required", + "type": "missing", + }, + ] + + @pytest.mark.parametrize( + "payload, expected", + [ + ({"end_date_lte": "2020-11-10T12:42:39.442973"}, "Input should have timezone info"), + ({"end_date_gte": "2020-11-10T12:42:39.442973"}, "Input should have timezone info"), + ({"start_date_lte": "2020-11-10T12:42:39.442973"}, "Input should have timezone info"), + ({"start_date_gte": "2020-11-10T12:42:39.442973"}, "Input should have timezone info"), + ({"logical_date_gte": "2020-11-10T12:42:39.442973"}, "Input should have timezone info"), + ({"logical_date_lte": "2020-11-10T12:42:39.442973"}, "Input should have timezone info"), + ], + ) + def test_should_raise_400_for_naive_and_bad_datetime(self, test_client, payload, expected, session): + self.create_task_instances(session) + response = test_client.post( + "/public/dags/~/dagRuns/~/taskInstances/list", + json=payload, + ) + assert response.status_code == 422 + assert expected in str(response.json()["detail"]) + + def test_should_respond_200_for_pagination(self, test_client, session): + dag_id = "example_python_operator" + + self.create_task_instances( + session, + task_instances=[ + {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(minutes=(i + 1))} for i in range(10) + ], + dag_id=dag_id, + ) + + # First 5 items + response_batch1 = test_client.post( + "/public/dags/~/dagRuns/~/taskInstances/list", + json={"page_limit": 5, "page_offset": 0}, + ) + assert response_batch1.status_code == 200, response_batch1.json() + num_entries_batch1 = len(response_batch1.json()["task_instances"]) + assert num_entries_batch1 == 5 + assert len(response_batch1.json()["task_instances"]) == 5 + + # 5 items after that + 
response_batch2 = test_client.post( + "/public/dags/~/dagRuns/~/taskInstances/list", + json={"page_limit": 5, "page_offset": 5}, + ) + assert response_batch2.status_code == 200, response_batch2.json() + num_entries_batch2 = len(response_batch2.json()["task_instances"]) + assert num_entries_batch2 > 0 + assert len(response_batch2.json()["task_instances"]) > 0 + + # Match + ti_count = 9 + assert response_batch1.json()["total_entries"] == response_batch2.json()["total_entries"] == ti_count + assert (num_entries_batch1 + num_entries_batch2) == ti_count + assert response_batch1 != response_batch2 + + # default limit and offset + response_batch3 = test_client.post( + "/public/dags/~/dagRuns/~/taskInstances/list", + json={}, + ) + + num_entries_batch3 = len(response_batch3.json()["task_instances"]) + assert num_entries_batch3 == ti_count + assert len(response_batch3.json()["task_instances"]) == ti_count
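
For reference, here is a minimal sketch of how a client could exercise the new batch endpoint once this change is deployed. The path and the payload fields come from `TaskInstancesBatchBody` and the tests above; the base URL, the `requests` library, and the omitted authentication are assumptions for illustration, not part of this PR.

```python
# Illustrative only: assumes an Airflow API server at http://localhost:8080 and
# whatever authentication your deployment requires (omitted here).
import requests

payload = {
    # Every field is optional; "~" in the path is the placeholder the tests use
    # to query across all DAGs / DAG runs.
    "dag_ids": ["example_python_operator"],
    "state": ["success", "failed"],
    "start_date_gte": "2024-01-01T00:00:00+00:00",  # naive datetimes are rejected with 422
    "page_limit": 50,
    "page_offset": 0,
    "order_by": "-start_date",
}

resp = requests.post(
    "http://localhost:8080/public/dags/~/dagRuns/~/taskInstances/list",
    json=payload,
)
resp.raise_for_status()
body = resp.json()
print(body["total_entries"], len(body["task_instances"]))
```

As with the legacy connexion endpoint at the same path, filtering, ordering, and pagination are all driven by the request body rather than query parameters, which is why the new handler builds `DagIdsFilter`, `RangeFilter`, `SortParam`, and the other filters from `body` instead of relying on the `Query`-based dependencies.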