Merge branch 'main' into fix-oss-handler-bug
EricGao888 authored May 15, 2024
2 parents 4580db0 + f411c14 commit 6949583
Showing 45 changed files with 2,484 additions and 1,349 deletions.
69 changes: 69 additions & 0 deletions airflow/api_connexion/endpoints/dag_parsing.py
@@ -0,0 +1,69 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from http import HTTPStatus
from typing import TYPE_CHECKING, Sequence

from flask import Response, current_app
from itsdangerous import BadSignature, URLSafeSerializer
from sqlalchemy import exc, select

from airflow.api_connexion import security
from airflow.api_connexion.exceptions import NotFound, PermissionDenied
from airflow.auth.managers.models.resource_details import DagDetails
from airflow.models.dag import DagModel
from airflow.models.dagbag import DagPriorityParsingRequest
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.www.extensions.init_auth_manager import get_auth_manager

if TYPE_CHECKING:
    from sqlalchemy.orm import Session

    from airflow.auth.managers.models.batch_apis import IsAuthorizedDagRequest


@security.requires_access_dag("PUT")
@provide_session
def reparse_dag_file(*, file_token: str, session: Session = NEW_SESSION) -> Response:
    """Request re-parsing a DAG file."""
    secret_key = current_app.config["SECRET_KEY"]
    auth_s = URLSafeSerializer(secret_key)
    try:
        path = auth_s.loads(file_token)
    except BadSignature:
        raise NotFound("File not found")

    requests: Sequence[IsAuthorizedDagRequest] = [
        {"method": "PUT", "details": DagDetails(id=dag_id)}
        for dag_id in session.scalars(select(DagModel.dag_id).where(DagModel.fileloc == path))
    ]
    if not requests:
        raise NotFound("File not found")

    # Check if user has edit access to all the DAGs defined in the file
    if not get_auth_manager().batch_is_authorized_dag(requests):
        raise PermissionDenied()

    parsing_request = DagPriorityParsingRequest(fileloc=path)
    session.add(parsing_request)
    try:
        session.commit()
    except exc.IntegrityError:
        session.rollback()
        return Response("Duplicate request", HTTPStatus.CREATED)
    return Response(status=HTTPStatus.CREATED)
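For context, a hedged client-side sketch of how this new handler might be exercised (not part of this commit). The base URL, credentials, and DAG index are illustrative placeholders; the `file_token` is the signed value the handler verifies with `URLSafeSerializer`, and one place clients already receive it is the `file_token` field of the `GET /api/v1/dags` response:

```python
# Hedged usage sketch; assumes a local webserver with basic auth enabled.
# Host and credentials are placeholders.
import requests

BASE = "http://localhost:8080/api/v1"
AUTH = ("admin", "admin")  # placeholder credentials

# The DAG collection response includes a signed file_token per DAG.
dags = requests.get(f"{BASE}/dags", auth=AUTH).json()["dags"]
token = dags[0]["file_token"]

# Ask the DAG processor to re-parse that DAG's file ahead of the queue.
resp = requests.put(f"{BASE}/parseDagFile/{token}", auth=AUTH)
print(resp.status_code)  # 201, both for a fresh and a duplicate request
```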
22 changes: 22 additions & 0 deletions airflow/api_connexion/openapi/v1.yaml
@@ -1091,6 +1091,27 @@ paths:
"404":
$ref: "#/components/responses/NotFound"

/parseDagFile/{file_token}:
parameters:
- $ref: "#/components/parameters/FileToken"

put:
summary: Request re-parsing of a DAG file
description: >
Request re-parsing of existing DAG files using a file token.
x-openapi-router-controller: airflow.api_connexion.endpoints.dag_parsing
operationId: reparse_dag_file
tags: [ DAG ]
responses:
"201":
description: Success.
"401":
$ref: "#/components/responses/Unauthenticated"
"403":
$ref: "#/components/responses/PermissionDenied"
"404":
$ref: "#/components/responses/NotFound"

/datasets/queuedEvent/{uri}:
parameters:
- $ref: "#/components/parameters/DatasetURI"
@@ -3159,6 +3180,7 @@ components:
            *New in version 2.5.0*
          nullable: true


    UpdateDagRunState:
      type: object
      description: |
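The `FileToken` parameter referenced above is an opaque, signed encoding of the DAG file's path. A minimal sketch of the round trip the endpoint relies on, with a placeholder secret key and path (in a real deployment the key is the webserver's `SECRET_KEY`):

```python
# Hedged sketch of the token round trip; key and path are placeholders.
from itsdangerous import URLSafeSerializer

serializer = URLSafeSerializer("placeholder-secret-key")
token = serializer.dumps("/opt/airflow/dags/example.py")  # what clients see
assert serializer.loads(token) == "/opt/airflow/dags/example.py"
```

A tampered or foreign token fails signature verification with `BadSignature`, which the handler above maps to a 404.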
4 changes: 2 additions & 2 deletions airflow/cli/cli_config.py
@@ -1860,14 +1860,14 @@ class GroupCommand(NamedTuple):
"(created by KubernetesExecutor/KubernetesPodOperator) "
"in evicted/failed/succeeded/pending states"
),
func=lazy_load_command("airflow.cli.commands.kubernetes_command.cleanup_pods"),
func=lazy_load_command("airflow.providers.cncf.kubernetes.cli.kubernetes_command.cleanup_pods"),
args=(ARG_NAMESPACE, ARG_MIN_PENDING_MINUTES, ARG_VERBOSE),
),
ActionCommand(
name="generate-dag-yaml",
help="Generate YAML files for all tasks in DAG. Useful for debugging tasks without "
"launching into a cluster",
func=lazy_load_command("airflow.cli.commands.kubernetes_command.generate_pod_yaml"),
func=lazy_load_command("airflow.providers.cncf.kubernetes.cli.kubernetes_command.generate_pod_yaml"),
args=(ARG_DAG_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_OUTPUT_PATH, ARG_VERBOSE),
),
)
7 changes: 7 additions & 0 deletions airflow/cli/commands/kubernetes_command.py
@@ -20,6 +20,7 @@

import os
import sys
import warnings
from datetime import datetime, timedelta

from kubernetes import client
@@ -36,6 +37,12 @@
from airflow.utils.cli import get_dag
from airflow.utils.providers_configuration_loader import providers_configuration_loaded

warnings.warn(
    "This module is deprecated. Use the kubernetes command from the cncf.kubernetes provider (>= 8.2.1) instead.",
    DeprecationWarning,
    stacklevel=2,
)


@cli_utils.action_cli
@providers_configuration_loaded
20 changes: 20 additions & 0 deletions airflow/dag_processing/manager.py
@@ -46,6 +46,7 @@
from airflow.configuration import conf
from airflow.dag_processing.processor import DagFileProcessorProcess
from airflow.models.dag import DagModel
from airflow.models.dagbag import DagPriorityParsingRequest
from airflow.models.dagwarning import DagWarning
from airflow.models.db_callback_request import DbCallbackRequest
from airflow.models.errors import ParseImportError
@@ -616,6 +617,7 @@ def _run_parsing_loop(self):
            elif refreshed_dag_dir:
                self.add_new_file_path_to_queue()

            self._refresh_requested_filelocs()
            self.start_new_processes()

            # Update number of loop iteration.
Expand Down Expand Up @@ -728,6 +730,24 @@ def _add_callback_to_queue(self, request: CallbackRequest):
            self._add_paths_to_queue([request.full_filepath], True)
            Stats.incr("dag_processing.other_callback_count")

    @provide_session
    def _refresh_requested_filelocs(self, session=NEW_SESSION) -> None:
        """Refresh file paths from dag dir as requested by users via APIs."""
        # Get values from DB table
        requests = session.scalars(select(DagPriorityParsingRequest))
        for request in requests:
            # Check if fileloc is in valid file paths. Parsing arbitrary
            # file paths could be a security issue.
            if request.fileloc in self._file_paths:
                # Try removing the fileloc if already present
                try:
                    self._file_path_queue.remove(request.fileloc)
                except ValueError:
                    pass
                # Enqueue the fileloc at the start of the queue.
                self._file_path_queue.appendleft(request.fileloc)
            session.delete(request)

    def _refresh_dag_dir(self) -> bool:
        """Refresh file paths from dag dir if we haven't done it for too long."""
        now = timezone.utcnow()
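A standalone sketch of the reordering `_refresh_requested_filelocs` performs on the parsing queue; the file paths are illustrative:

```python
# Move a requested fileloc to the front of the queue, dropping any
# existing occurrence first, mirroring the logic added above.
from collections import deque

file_path_queue = deque(["/dags/a.py", "/dags/b.py", "/dags/c.py"])
requested = "/dags/b.py"

try:
    file_path_queue.remove(requested)  # drop the existing entry, if any
except ValueError:
    pass  # not queued yet; nothing to remove
file_path_queue.appendleft(requested)  # parse it next

print(file_path_queue)  # deque(['/dags/b.py', '/dags/a.py', '/dags/c.py'])
```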
New file: Alembic migration (revision c4602ba06b4b) under airflow/migrations/versions/
@@ -0,0 +1,52 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Added DagPriorityParsingRequest table.
Revision ID: c4602ba06b4b
Revises: 677fdbb7fc54
Create Date: 2024-04-17 17:12:05.473889
"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "c4602ba06b4b"
down_revision = "677fdbb7fc54"
branch_labels = None
depends_on = None
airflow_version = "2.10.0"


def upgrade():
    """Apply Added DagPriorityParsingRequest table."""
    op.create_table(
        "dag_priority_parsing_request",
        sa.Column("id", sa.String(length=32), nullable=False),
        sa.Column("fileloc", sa.String(length=2000), nullable=False),
        sa.PrimaryKeyConstraint("id", name=op.f("dag_priority_parsing_request_pkey")),
    )


def downgrade():
    """Unapply Added DagPriorityParsingRequest table."""
    op.drop_table("dag_priority_parsing_request")
35 changes: 35 additions & 0 deletions airflow/models/dagbag.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import hashlib
import importlib
import importlib.machinery
import importlib.util
@@ -30,6 +31,10 @@
from pathlib import Path
from typing import TYPE_CHECKING, NamedTuple

from sqlalchemy import (
    Column,
    String,
)
from sqlalchemy.exc import OperationalError
from tabulate import tabulate

@@ -43,6 +48,7 @@
    AirflowDagDuplicatedIdException,
    RemovedInAirflow3Warning,
)
from airflow.models.base import Base
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.dag_cycle_tester import check_cycle
@@ -727,3 +733,32 @@ def _sync_perm_for_dag(cls, dag: DAG, session: Session = NEW_SESSION):

        security_manager = ApplessAirflowSecurityManager(session=session)
        security_manager.sync_perm_for_dag(root_dag_id, dag.access_control)


def generate_md5_hash(context):
    """Generate an MD5 hash of the fileloc to use as the row's primary key."""
    fileloc = context.get_current_parameters()["fileloc"]
    return hashlib.md5(fileloc.encode()).hexdigest()


class DagPriorityParsingRequest(Base):
    """Model to store the dag parsing requests that will be prioritized when parsing files."""

    __tablename__ = "dag_priority_parsing_request"

    # Adding a unique constraint to fileloc results in the creation of an index, and MySQL limits
    # the size of strings that can be indexed. We also have to keep the fileloc size consistent
    # with other tables. Hashing the fileloc into the primary key is a workaround to enforce the
    # unique constraint.
    id = Column(String(32), primary_key=True, default=generate_md5_hash, onupdate=generate_md5_hash)

    # The location of the file containing the DAG object
    # Note: Do not depend on fileloc pointing to a file; in the case of a
    # packaged DAG, it will point to the subpath of the DAG within the
    # associated zip.
    fileloc = Column(String(2000), nullable=False)

    def __init__(self, fileloc: str) -> None:
        super().__init__()
        self.fileloc = fileloc

    def __repr__(self) -> str:
        return f"<DagPriorityParsingRequest: fileloc={self.fileloc}>"
50 changes: 47 additions & 3 deletions airflow/providers/amazon/aws/executors/batch/batch_executor.py
@@ -19,10 +19,11 @@

from __future__ import annotations

import contextlib
import time
from collections import defaultdict, deque
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Sequence

from botocore.exceptions import ClientError, NoCredentialsError

@@ -34,11 +35,12 @@
    exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.helpers import merge_dicts

if TYPE_CHECKING:
    from airflow.models.taskinstance import TaskInstanceKey
    from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
    from airflow.providers.amazon.aws.executors.batch.boto_schema import (
        BatchDescribeJobsResponseSchema,
        BatchSubmitJobResponseSchema,
@@ -306,14 +308,20 @@ def attempt_submit_jobs(self):
                self.pending_jobs.append(batch_job)
            else:
                # Success case
                job_id = submit_job_response["job_id"]
                self.active_workers.add_job(
                    job_id=submit_job_response["job_id"],
                    job_id=job_id,
                    airflow_task_key=key,
                    airflow_cmd=cmd,
                    queue=queue,
                    exec_config=exec_config,
                    attempt_number=attempt_number,
                )
                with contextlib.suppress(AttributeError):
                    # TODO: Remove this when min_airflow_version is 2.10.0 or higher in Amazon provider.
                    # running_state is added in Airflow 2.10 and only needed to support task adoption
                    # (an optional executor feature).
                    self.running_state(key, job_id)
        if failure_reasons:
            self.log.error(
                "Pending Batch jobs failed to launch for the following reasons: %s. Retrying later.",
@@ -418,3 +426,39 @@ def _load_submit_kwargs() -> dict:
" and value should be NULL or empty."
)
return submit_kwargs

    def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
        """
        Adopt task instances which have an external_executor_id (the Batch job ID).

        Anything that is not adopted will be cleared by the scheduler and becomes eligible for re-scheduling.
        """
        with Stats.timer("batch_executor.adopt_task_instances.duration"):
            adopted_tis: list[TaskInstance] = []

            if job_ids := [ti.external_executor_id for ti in tis if ti.external_executor_id]:
                batch_jobs = self._describe_jobs(job_ids)

                for batch_job in batch_jobs:
                    ti = next(ti for ti in tis if ti.external_executor_id == batch_job.job_id)
                    self.active_workers.add_job(
                        job_id=batch_job.job_id,
                        airflow_task_key=ti.key,
                        airflow_cmd=ti.command_as_list(),
                        queue=ti.queue,
                        exec_config=ti.executor_config,
                        attempt_number=ti.prev_attempted_tries,
                    )
                    adopted_tis.append(ti)

            if adopted_tis:
                tasks = [f"{task} in state {task.state}" for task in adopted_tis]
                task_instance_str = "\n\t".join(tasks)
                self.log.info(
                    "Adopted the following %d tasks from a dead executor:\n\t%s",
                    len(adopted_tis),
                    task_instance_str,
                )

            not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
            return not_adopted_tis
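The `contextlib.suppress(AttributeError)` guard in `attempt_submit_jobs` is a version-compatibility pattern worth noting: the optional hook is simply attempted, and on cores that predate it the failed attribute lookup is silently ignored instead of gating on a version check. A minimal standalone sketch, with stand-in classes:

```python
# Hedged sketch of the compatibility pattern; the classes are stand-ins.
import contextlib

class OldCore:  # models an Airflow core without running_state()
    pass

class NewCore:  # models a core that provides the optional hook
    def running_state(self, key: str, job_id: str) -> None:
        print(f"marked {key} running as {job_id}")

for executor in (OldCore(), NewCore()):
    with contextlib.suppress(AttributeError):
        executor.running_state("task_key", "job-123")  # no-op on OldCore
```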
