Skip to content

Commit

Permalink
Move test to accomodate function move
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr committed Sep 18, 2024
1 parent afc116a commit 96c89c5
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 41 deletions.
64 changes: 64 additions & 0 deletions tests/dag_processing/test_collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import warnings

from sqlalchemy.exc import SAWarning

from airflow.dag_processing.collection import _get_latest_runs_stmt


def test_statement_latest_runs_one_dag():
with warnings.catch_warnings():
warnings.simplefilter("error", category=SAWarning)

stmt = _get_latest_runs_stmt(["fake-dag"])
compiled_stmt = str(stmt.compile())
actual = [x.strip() for x in compiled_stmt.splitlines()]
expected = [
"SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
"dag_run.data_interval_start, dag_run.data_interval_end",
"FROM dag_run",
"WHERE dag_run.dag_id = :dag_id_1 AND dag_run.logical_date = ("
"SELECT max(dag_run.logical_date) AS max_execution_date",
"FROM dag_run",
"WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))",
]
assert actual == expected, compiled_stmt


def test_statement_latest_runs_many_dag():
with warnings.catch_warnings():
warnings.simplefilter("error", category=SAWarning)

stmt = _get_latest_runs_stmt(["fake-dag-1", "fake-dag-2"])
compiled_stmt = str(stmt.compile())
actual = [x.strip() for x in compiled_stmt.splitlines()]
expected = [
"SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
"dag_run.data_interval_start, dag_run.data_interval_end",
"FROM dag_run, (SELECT dag_run.dag_id AS dag_id, "
"max(dag_run.logical_date) AS max_execution_date",
"FROM dag_run",
"WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) "
"AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1",
"WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.logical_date = anon_1.max_execution_date",
]
assert actual == expected, compiled_stmt
41 changes: 0 additions & 41 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import os
import pickle
import re
import warnings
import weakref
from datetime import timedelta
from importlib import reload
Expand All @@ -37,7 +36,6 @@
import pytest
import time_machine
from sqlalchemy import inspect, select
from sqlalchemy.exc import SAWarning

from airflow import settings
from airflow.configuration import conf
Expand Down Expand Up @@ -3992,42 +3990,3 @@ def test_validate_setup_teardown_trigger_rule(self):
Exception, match="Setup tasks must be followed with trigger rule ALL_SUCCESS."
):
dag.validate_setup_teardown()


def test_statement_latest_runs_one_dag():
with warnings.catch_warnings():
warnings.simplefilter("error", category=SAWarning)

stmt = DAG._get_latest_runs_stmt(dags=["fake-dag"])
compiled_stmt = str(stmt.compile())
actual = [x.strip() for x in compiled_stmt.splitlines()]
expected = [
"SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
"dag_run.data_interval_start, dag_run.data_interval_end",
"FROM dag_run",
"WHERE dag_run.dag_id = :dag_id_1 AND dag_run.logical_date = ("
"SELECT max(dag_run.logical_date) AS max_execution_date",
"FROM dag_run",
"WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))",
]
assert actual == expected, compiled_stmt


def test_statement_latest_runs_many_dag():
with warnings.catch_warnings():
warnings.simplefilter("error", category=SAWarning)

stmt = DAG._get_latest_runs_stmt(dags=["fake-dag-1", "fake-dag-2"])
compiled_stmt = str(stmt.compile())
actual = [x.strip() for x in compiled_stmt.splitlines()]
expected = [
"SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
"dag_run.data_interval_start, dag_run.data_interval_end",
"FROM dag_run, (SELECT dag_run.dag_id AS dag_id, "
"max(dag_run.logical_date) AS max_execution_date",
"FROM dag_run",
"WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) "
"AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1",
"WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.logical_date = anon_1.max_execution_date",
]
assert actual == expected, compiled_stmt

0 comments on commit 96c89c5

Please sign in to comment.