Skip to content

Commit

Permalink
Restore stability and unquarantine all test_scheduler_job tests (#19860)
Browse files Browse the repository at this point in the history
* Restore stability and unquarantine all test_scheduler_job tests

The scheduler job tests were pretty flaky and some of them were
already quarantined (especially the query count tests). This PR improves
their stability in the following ways:

* clean the database between tests for TestSchedulerJob to avoid
  side effects
* force the UTC timezone in tests where dates lacked timezone specs
* update the number of queries expected in the query count tests
* stabilize the order in which tasks are retrieved in case tests
  depended on it
* add more stack trace levels (5 instead of 3) to help compare where
  extra methods were called.
* increase number of scheduler runs where it was needed
* add session.flush() where it was missing
* add requirement to have serialized dags ready when needed
* increase the number of dag runs to process in tests that could
  be "too slow" compared to the fast processing of
  dag runs.

Hopefully:

* Fixes: #18777
* Fixes: #17291
* Fixes: #17224
* Fixes: #15255
* Fixes: #15085

* Update tests/jobs/test_scheduler_job.py

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
(cherry picked from commit 9b277db)
  • Loading branch information
potiuk authored and jedcunningham committed Jan 27, 2022
1 parent 3fa1535 commit ad00e8e
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 83 deletions.
142 changes: 63 additions & 79 deletions tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import shutil
from datetime import timedelta
from tempfile import mkdtemp
from typing import Generator, Optional
from unittest import mock
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -66,6 +67,7 @@
)
from tests.test_utils.mock_executor import MockExecutor
from tests.test_utils.mock_operators import CustomOperator
from tests.utils.test_timezone import UTC

ROOT_FOLDER = os.path.realpath(
os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir)
Expand Down Expand Up @@ -110,24 +112,32 @@ def clean_db():
# The tests expect DAGs to be fully loaded here via setUpClass method below

@pytest.fixture(autouse=True)
def set_instance_attrs(self, dagbag):
self.dagbag = dagbag
def per_test(self) -> Generator:
self.clean_db()
self.scheduler_job = None

yield

if self.scheduler_job and self.scheduler_job.processor_agent:
self.scheduler_job.processor_agent.end()
self.scheduler_job = None
self.clean_db()

@pytest.fixture(autouse=True)
def set_instance_attrs(self, dagbag) -> Generator:
self.dagbag = dagbag
# Speed up some tests by not running the tasks, just look at what we
# enqueue!
self.null_exec = MockExecutor()
self.null_exec: Optional[MockExecutor] = MockExecutor()

# Since we don't want to store the code for the DAG defined in this file
with patch('airflow.dag_processing.manager.SerializedDagModel.remove_deleted_dags'), patch(
'airflow.models.dag.DagCode.bulk_sync_to_db'
):
yield

if self.scheduler_job and self.scheduler_job.processor_agent:
self.scheduler_job.processor_agent.end()
self.scheduler_job = None
self.clean_db()
self.null_exec = None
self.dagbag = None

def test_is_alive(self):
self.scheduler_job = SchedulerJob(None, heartrate=10, state=State.RUNNING)
Expand Down Expand Up @@ -166,7 +176,6 @@ def run_single_scheduler_loop_with_no_dags(self, dags_folder):
self.scheduler_job.heartrate = 0
self.scheduler_job.run()

@pytest.mark.quarantined
def test_no_orphan_process_will_be_left(self):
empty_dir = mkdtemp()
current_process = psutil.Process()
Expand Down Expand Up @@ -443,15 +452,20 @@ def test_find_executable_task_instances_pool(self, dag_maker):
task_id_2 = 'dummydummy'
session = settings.Session()
with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session):
DummyOperator(task_id=task_id_1, pool='a')
DummyOperator(task_id=task_id_2, pool='b')
DummyOperator(task_id=task_id_1, pool='a', priority_weight=2)
DummyOperator(task_id=task_id_2, pool='b', priority_weight=1)

self.scheduler_job = SchedulerJob(subdir=os.devnull)

dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)

tis = dr1.task_instances + dr2.task_instances
tis = [
dr1.get_task_instance(task_id_1, session=session),
dr1.get_task_instance(task_id_2, session=session),
dr2.get_task_instance(task_id_1, session=session),
dr2.get_task_instance(task_id_2, session=session),
]
for ti in tis:
ti.state = State.SCHEDULED
session.merge(ti)
Expand Down Expand Up @@ -1705,13 +1719,14 @@ def test_scheduler_start_date(self):
session.commit()
assert [] == self.null_exec.sorted_tasks

@pytest.mark.quarantined
def test_scheduler_task_start_date(self):
"""
Test that the scheduler respects task start dates that are different from DAG start dates
"""

dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), include_examples=False)
dagbag = DagBag(
dag_folder=os.path.join(settings.DAGS_FOLDER, "test_scheduler_dags.py"), include_examples=False
)
dag_id = 'test_task_start_date_scheduling'
dag = self.dagbag.get_dag(dag_id)
dag.is_paused_upon_creation = False
Expand All @@ -1724,15 +1739,15 @@ def test_scheduler_task_start_date(self):

dagbag.sync_to_db()

self.scheduler_job = SchedulerJob(executor=self.null_exec, subdir=dag.fileloc, num_runs=2)
self.scheduler_job = SchedulerJob(executor=self.null_exec, subdir=dag.fileloc, num_runs=3)
self.scheduler_job.run()

session = settings.Session()
tiq = session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id)
ti1s = tiq.filter(TaskInstance.task_id == 'dummy1').all()
ti2s = tiq.filter(TaskInstance.task_id == 'dummy2').all()
assert len(ti1s) == 0
assert len(ti2s) == 2
assert len(ti2s) >= 2
for task in ti2s:
assert task.state == State.SUCCESS

Expand All @@ -1757,31 +1772,6 @@ def test_scheduler_multiprocessing(self):
session = settings.Session()
assert len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()) == 0

@conf_vars({("core", "mp_start_method"): "spawn"})
def test_scheduler_multiprocessing_with_spawn_method(self):
"""
Test that the scheduler can successfully queue multiple dags in parallel
when using "spawn" mode of multiprocessing. (Fork is default on Linux and older OSX)
"""
dag_ids = ['test_start_date_scheduling', 'test_dagrun_states_success']
for dag_id in dag_ids:
dag = self.dagbag.get_dag(dag_id)
dag.clear()

self.scheduler_job = SchedulerJob(
executor=self.null_exec,
subdir=os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py'),
num_runs=1,
)

self.scheduler_job.run()

# zero tasks ran
dag_id = 'test_start_date_scheduling'
with create_session() as session:
assert session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).count() == 0

@pytest.mark.quarantined
def test_scheduler_verify_pool_full(self, dag_maker):
"""
Test task instances not queued when pool is full
Expand All @@ -1808,6 +1798,7 @@ def test_scheduler_verify_pool_full(self, dag_maker):
self.scheduler_job._schedule_dag_run(dr, session)
dr = dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, state=State.RUNNING)
self.scheduler_job._schedule_dag_run(dr, session)
session.flush()
task_instances_list = self.scheduler_job._executable_task_instances_to_queued(
max_tis=32, session=session
)
Expand Down Expand Up @@ -1858,7 +1849,6 @@ def _create_dagruns():
# As tasks require 2 slots, only 3 can fit into 6 available
assert len(task_instances_list) == 3

@pytest.mark.quarantined
def test_scheduler_keeps_scheduling_pool_full(self, dag_maker):
"""
Test task instances in a pool that isn't full keep getting scheduled even when a pool is full.
Expand Down Expand Up @@ -1897,16 +1887,17 @@ def test_scheduler_keeps_scheduling_pool_full(self, dag_maker):

def _create_dagruns(dag: DAG):
next_info = dag.next_dagrun_info(None)
for _ in range(5):
assert next_info is not None
for _ in range(30):
yield dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=next_info.logical_date,
data_interval=next_info.data_interval,
state=State.RUNNING,
state=DagRunState.RUNNING,
)
next_info = dag.next_dagrun_info(next_info.data_interval)

# Create 5 dagruns for each DAG.
# Create 30 dagruns for each DAG.
# To increase the chances the TIs from the "full" pool will get retrieved first, we schedule all
# TIs from the first dag first.
for dr in _create_dagruns(dag_d1):
Expand Down Expand Up @@ -2048,7 +2039,6 @@ def test_verify_integrity_if_dag_not_changed(self, dag_maker):
session.rollback()
session.close()

@pytest.mark.quarantined
def test_verify_integrity_if_dag_changed(self, dag_maker):
# CleanUp
with create_session() as session:
Expand Down Expand Up @@ -2113,7 +2103,6 @@ def test_verify_integrity_if_dag_changed(self, dag_maker):
session.rollback()
session.close()

@pytest.mark.quarantined
@pytest.mark.need_serialized_dag
def test_retry_still_in_executor(self, dag_maker):
"""
Expand Down Expand Up @@ -2889,6 +2878,8 @@ def complete_one_dagrun():
ti.state = State.SUCCESS
session.flush()

self.clean_db()

with dag_maker(max_active_runs=3, session=session) as dag:
# Need to use something that doesn't immediately get marked as success by the scheduler
BashOperator(task_id='task', bash_command='true')
Expand All @@ -2906,13 +2897,7 @@ def complete_one_dagrun():
# Pre-condition
assert DagRun.active_runs_of_dags(session=session) == {'test_dag': 3}

assert model.next_dagrun == timezone.convert_to_utc(
timezone.DateTime(
2016,
1,
3,
)
)
assert model.next_dagrun == timezone.DateTime(2016, 1, 3, tzinfo=UTC)
assert model.next_dagrun_create_after is None

complete_one_dagrun()
Expand Down Expand Up @@ -3423,8 +3408,6 @@ def test_task_with_upstream_skip_process_task_instances():
assert tis[dummy3.task_id].state == State.SKIPPED


# TODO(potiuk): unquarantine me where we get rid of those pesky 195 -> 196 problem!
@pytest.mark.quarantined
class TestSchedulerJobQueriesCount:
"""
These tests are designed to detect changes in the number of queries for
Expand Down Expand Up @@ -3456,9 +3439,9 @@ def per_test(self) -> None:
@pytest.mark.parametrize(
"expected_query_count, dag_count, task_count",
[
(20, 1, 1), # One DAG with one task per DAG file.
(20, 1, 5), # One DAG with five tasks per DAG file.
(83, 10, 10), # 10 DAGs with 10 tasks per DAG file.
(21, 1, 1), # One DAG with one task per DAG file.
(21, 1, 5), # One DAG with five tasks per DAG file.
(93, 10, 10), # 10 DAGs with 10 tasks per DAG file.
],
)
def test_execute_queries_count_with_harvested_dags(self, expected_query_count, dag_count, task_count):
Expand Down Expand Up @@ -3519,33 +3502,33 @@ def test_execute_queries_count_with_harvested_dags(self, expected_query_count, d
# One DAG with one task per DAG file.
([10, 10, 10, 10], 1, 1, "1d", "None", "no_structure"),
([10, 10, 10, 10], 1, 1, "1d", "None", "linear"),
([23, 13, 13, 13], 1, 1, "1d", "@once", "no_structure"),
([23, 13, 13, 13], 1, 1, "1d", "@once", "linear"),
([23, 24, 26, 28], 1, 1, "1d", "30m", "no_structure"),
([23, 24, 26, 28], 1, 1, "1d", "30m", "linear"),
([23, 24, 26, 28], 1, 1, "1d", "30m", "binary_tree"),
([23, 24, 26, 28], 1, 1, "1d", "30m", "star"),
([23, 24, 26, 28], 1, 1, "1d", "30m", "grid"),
([24, 14, 14, 14], 1, 1, "1d", "@once", "no_structure"),
([24, 14, 14, 14], 1, 1, "1d", "@once", "linear"),
([24, 26, 29, 32], 1, 1, "1d", "30m", "no_structure"),
([24, 26, 29, 32], 1, 1, "1d", "30m", "linear"),
([24, 26, 29, 32], 1, 1, "1d", "30m", "binary_tree"),
([24, 26, 29, 32], 1, 1, "1d", "30m", "star"),
([24, 26, 29, 32], 1, 1, "1d", "30m", "grid"),
# One DAG with five tasks per DAG file.
([10, 10, 10, 10], 1, 5, "1d", "None", "no_structure"),
([10, 10, 10, 10], 1, 5, "1d", "None", "linear"),
([23, 13, 13, 13], 1, 5, "1d", "@once", "no_structure"),
([24, 14, 14, 14], 1, 5, "1d", "@once", "linear"),
([23, 24, 26, 28], 1, 5, "1d", "30m", "no_structure"),
([24, 26, 29, 32], 1, 5, "1d", "30m", "linear"),
([24, 26, 29, 32], 1, 5, "1d", "30m", "binary_tree"),
([24, 26, 29, 32], 1, 5, "1d", "30m", "star"),
([24, 26, 29, 32], 1, 5, "1d", "30m", "grid"),
([24, 14, 14, 14], 1, 5, "1d", "@once", "no_structure"),
([25, 15, 15, 15], 1, 5, "1d", "@once", "linear"),
([24, 26, 29, 32], 1, 5, "1d", "30m", "no_structure"),
([25, 28, 32, 36], 1, 5, "1d", "30m", "linear"),
([25, 28, 32, 36], 1, 5, "1d", "30m", "binary_tree"),
([25, 28, 32, 36], 1, 5, "1d", "30m", "star"),
([25, 28, 32, 36], 1, 5, "1d", "30m", "grid"),
# 10 DAGs with 10 tasks per DAG file.
([10, 10, 10, 10], 10, 10, "1d", "None", "no_structure"),
([10, 10, 10, 10], 10, 10, "1d", "None", "linear"),
([95, 28, 28, 28], 10, 10, "1d", "@once", "no_structure"),
([105, 41, 41, 41], 10, 10, "1d", "@once", "linear"),
([95, 99, 99, 99], 10, 10, "1d", "30m", "no_structure"),
([105, 125, 125, 125], 10, 10, "1d", "30m", "linear"),
([105, 119, 119, 119], 10, 10, "1d", "30m", "binary_tree"),
([105, 119, 119, 119], 10, 10, "1d", "30m", "star"),
([105, 119, 119, 119], 10, 10, "1d", "30m", "grid"),
([105, 38, 38, 38], 10, 10, "1d", "@once", "no_structure"),
([115, 51, 51, 51], 10, 10, "1d", "@once", "linear"),
([105, 119, 119, 119], 10, 10, "1d", "30m", "no_structure"),
([115, 145, 145, 145], 10, 10, "1d", "30m", "linear"),
([115, 139, 139, 139], 10, 10, "1d", "30m", "binary_tree"),
([115, 139, 139, 139], 10, 10, "1d", "30m", "star"),
([115, 139, 139, 139], 10, 10, "1d", "30m", "grid"),
],
)
def test_process_dags_queries_count(
Expand Down Expand Up @@ -3669,6 +3652,7 @@ def test_should_mark_dummy_task_as_success(self):
assert end_date is None
assert duration is None

@pytest.mark.need_serialized_dag
def test_catchup_works_correctly(self, dag_maker):
"""Test that catchup works correctly"""
session = settings.Session()
Expand Down
12 changes: 8 additions & 4 deletions tests/test_utils/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def after_cursor_execute(self, *args, **kwargs):
and __file__ != f.filename
and ('session.py' not in f.filename and f.name != 'wrapper')
]
stack_info = ">".join([f"{f.filename.rpartition('/')[-1]}:{f.name}:{f.lineno}" for f in stack][-3:])
stack_info = ">".join([f"{f.filename.rpartition('/')[-1]}:{f.name}:{f.lineno}" for f in stack][-5:])
self.result[f"{stack_info}"] += 1


Expand All @@ -75,15 +75,19 @@ def assert_queries_count(expected_count, message_fmt=None):
with count_queries() as result:
yield None

# This is a margin we allow for queries - we do not want to update the expected count every
# time the queries change, but we want to catch cases where the count spins out of control
margin = 15

count = sum(result.values())
if expected_count != count:
if count > expected_count + margin:
message_fmt = (
message_fmt
or "The expected number of db queries is {expected_count}. "
or "The expected number of db queries is {expected_count} with extra margin: {margin}. "
"The current number is {current_count}.\n\n"
"Recorded query locations:"
)
message = message_fmt.format(current_count=count, expected_count=expected_count)
message = message_fmt.format(current_count=count, expected_count=expected_count, margin=margin)

for location, count in result.items():
message += f'\n\t{location}:\t{count}'
Expand Down

0 comments on commit ad00e8e

Please sign in to comment.