Skip to content

Commit

Permalink
Remove DAG parsing from StandardTaskRunner (#26750)
Browse files Browse the repository at this point in the history
This makes the starting of StandardTaskRunner faster as the parsing of DAG will now be done once at task_run.
Also removed parsing of example dags when running a task
  • Loading branch information
ephraimbuddy authored Sep 29, 2022
1 parent 2e66d2d commit ce07117
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 91 deletions.
10 changes: 1 addition & 9 deletions airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from airflow.utils import cli as cli_utils
from airflow.utils.cli import (
get_dag,
get_dag_by_deserialization,
get_dag_by_file_location,
get_dag_by_pickle,
get_dags,
Expand Down Expand Up @@ -364,14 +363,7 @@ def task_run(args, dag=None):
print(f'Loading pickle id: {args.pickle}')
dag = get_dag_by_pickle(args.pickle)
elif not dag:
if args.local:
try:
dag = get_dag_by_deserialization(args.dag_id)
except AirflowException:
print(f'DAG {args.dag_id} does not exist in the database, trying to parse the dag_file')
dag = get_dag(args.subdir, args.dag_id)
else:
dag = get_dag(args.subdir, args.dag_id)
dag = get_dag(args.subdir, args.dag_id, include_examples=False)
else:
# Use DAG from parameter
pass
Expand Down
8 changes: 3 additions & 5 deletions airflow/task/task_runner/standard_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class StandardTaskRunner(BaseTaskRunner):
def __init__(self, local_task_job):
super().__init__(local_task_job)
self._rc = None
self.dag = local_task_job.task_instance.task.dag

def start(self):
if CAN_FORK and not self.run_as_user:
Expand Down Expand Up @@ -64,7 +65,6 @@ def _start_by_fork(self):
from airflow import settings
from airflow.cli.cli_parser import get_parser
from airflow.sentry import Sentry
from airflow.utils.cli import get_dag

# Force a new SQLAlchemy session. We can't share open DB handles
# between process. The cli code will re-create this as part of its
Expand Down Expand Up @@ -92,10 +92,8 @@ def _start_by_fork(self):
dag_id=self._task_instance.dag_id,
task_id=self._task_instance.task_id,
):
# parse dag file since `airflow tasks run --local` does not parse dag file
dag = get_dag(args.subdir, args.dag_id)
args.func(args, dag=dag)
return_code = 0
args.func(args, dag=self.dag)
return_code = 0
except Exception as exc:
return_code = 1

Expand Down
19 changes: 6 additions & 13 deletions airflow/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from typing import TYPE_CHECKING, Callable, TypeVar, cast

from airflow import settings
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.utils import cli_action_loggers
from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
Expand Down Expand Up @@ -205,7 +206,9 @@ def _search_for_dag_file(val: str | None) -> str | None:
return None


def get_dag(subdir: str | None, dag_id: str) -> DAG:
def get_dag(
subdir: str | None, dag_id: str, include_examples=conf.getboolean('core', 'LOAD_EXAMPLES')
) -> DAG:
"""
Returns DAG of a given dag_id
Expand All @@ -216,28 +219,18 @@ def get_dag(subdir: str | None, dag_id: str) -> DAG:
from airflow.models import DagBag

first_path = process_subdir(subdir)
dagbag = DagBag(first_path)
dagbag = DagBag(first_path, include_examples=include_examples)
if dag_id not in dagbag.dags:
fallback_path = _search_for_dag_file(subdir) or settings.DAGS_FOLDER
logger.warning("Dag %r not found in path %s; trying path %s", dag_id, first_path, fallback_path)
dagbag = DagBag(dag_folder=fallback_path)
dagbag = DagBag(dag_folder=fallback_path, include_examples=include_examples)
if dag_id not in dagbag.dags:
raise AirflowException(
f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse."
)
return dagbag.dags[dag_id]


def get_dag_by_deserialization(dag_id: str) -> DAG:
from airflow.models.serialized_dag import SerializedDagModel

dag_model = SerializedDagModel.get(dag_id)
if dag_model is None:
raise AirflowException(f"Serialized DAG: {dag_id} could not be found")

return dag_model.dag


def get_dags(subdir: str | None, dag_id: str, use_regex: bool = False):
"""Returns DAG(s) matching a given regex or dag_id"""
from airflow.models import DagBag
Expand Down
64 changes: 0 additions & 64 deletions tests/cli/commands/test_task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,38 +159,6 @@ def test_test_filters_secrets(self, capsys):
task_command.task_test(args)
assert capsys.readouterr().out.endswith(f"{not_password}\n")

@mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization")
@mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
def test_run_get_serialized_dag(self, mock_local_job, mock_get_dag_by_deserialization):
"""
Test using serialized dag for local task_run
"""
task_id = self.dag.task_ids[0]
args = [
'tasks',
'run',
'--ignore-all-dependencies',
'--local',
self.dag_id,
task_id,
self.run_id,
]
mock_get_dag_by_deserialization.return_value = SerializedDagModel.get(self.dag_id).dag

task_command.task_run(self.parser.parse_args(args))
mock_local_job.assert_called_once_with(
task_instance=mock.ANY,
mark_success=False,
ignore_all_deps=True,
ignore_depends_on_past=False,
ignore_task_deps=False,
ignore_ti_state=False,
pickle_id=None,
pool=None,
external_executor_id=None,
)
mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id)

def test_cli_test_different_path(self, session):
"""
When thedag processor has a different dags folder
Expand Down Expand Up @@ -265,38 +233,6 @@ def test_cli_test_different_path(self, session):
# verify that the file was in different location when run
assert ti.xcom_pull(ti.task_id) == new_file_path.as_posix()

@mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization")
@mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
def test_run_get_serialized_dag_fallback(self, mock_local_job, mock_get_dag_by_deserialization):
"""
Fallback to parse dag_file when serialized dag does not exist in the db
"""
task_id = self.dag.task_ids[0]
args = [
'tasks',
'run',
'--ignore-all-dependencies',
'--local',
self.dag_id,
task_id,
self.run_id,
]
mock_get_dag_by_deserialization.side_effect = mock.Mock(side_effect=AirflowException('Not found'))

task_command.task_run(self.parser.parse_args(args))
mock_local_job.assert_called_once_with(
task_instance=mock.ANY,
mark_success=False,
ignore_all_deps=True,
ignore_depends_on_past=False,
ignore_task_deps=False,
ignore_ti_state=False,
pickle_id=None,
pool=None,
external_executor_id=None,
)
mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id)

@mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
def test_run_with_existing_dag_run_id(self, mock_local_job):
"""
Expand Down

0 comments on commit ce07117

Please sign in to comment.