diff --git a/airflow/providers/amazon/aws/hooks/redshift_sql.py b/airflow/providers/amazon/aws/hooks/redshift_sql.py
index 33450d61ca40c..8df15a8eead15 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_sql.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -20,14 +20,19 @@
 from typing import TYPE_CHECKING
 
 import redshift_connector
+from packaging.version import Version
 from redshift_connector import Connection as RedshiftConnection
 from sqlalchemy import create_engine
 from sqlalchemy.engine.url import URL
 
+from airflow import __version__ as AIRFLOW_VERSION
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 
+_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")
+
+
 if TYPE_CHECKING:
     from airflow.models.connection import Connection
     from airflow.providers.openlineage.sqlparser import DatabaseInfo
@@ -257,4 +262,6 @@ def get_openlineage_database_dialect(self, connection: Connection) -> str:
 
     def get_openlineage_default_schema(self) -> str | None:
         """Return current schema. This is usually changed with ``SEARCH_PATH`` parameter."""
-        return self.get_first("SELECT CURRENT_SCHEMA();")[0]
+        if _IS_AIRFLOW_2_10_OR_HIGHER:
+            return self.get_first("SELECT CURRENT_SCHEMA();")[0]
+        return super().get_openlineage_default_schema()
diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py
index a56bf58884f58..ed217bdc7aa23 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -29,7 +29,9 @@
 import attrs
 from deprecated import deprecated
 from openlineage.client.utils import RedactMixin
+from packaging.version import Version
 
+from airflow import __version__ as AIRFLOW_VERSION
 from airflow.exceptions import AirflowProviderDeprecationWarning  # TODO: move this maybe to Airflow's logic?
 from airflow.models import DAG, BaseOperator, MappedOperator
 from airflow.providers.openlineage import conf
@@ -57,6 +59,7 @@
 log = logging.getLogger(__name__)
 
 _NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
+_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")
 
 
 def try_import_from_string(string: str) -> Any:
@@ -558,5 +561,7 @@ def normalize_sql(sql: str | Iterable[str]):
 
 
 def should_use_external_connection(hook) -> bool:
-    # TODO: Add checking overrides
-    return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook"]
+    # On Airflow 2.10+ execution is process-isolated, so we can safely use these hooks' connections again.
+    if not _IS_AIRFLOW_2_10_OR_HIGHER:
+        return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook", "RedshiftSQLHook"]
+    return True
diff --git a/tests/providers/amazon/aws/operators/test_redshift_sql.py b/tests/providers/amazon/aws/operators/test_redshift_sql.py
index 586172c3b87a9..003c40e615b70 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_sql.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_sql.py
@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations
 
-from unittest.mock import MagicMock, call, patch
+from unittest.mock import MagicMock, PropertyMock, call, patch
 
 import pytest
 from openlineage.client.facet import (
@@ -31,7 +31,7 @@
 from openlineage.client.run import Dataset
 
 from airflow.models.connection import Connection
-from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
+from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook as OriginalRedshiftSQLHook
 from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 
 MOCK_REGION_NAME = "eu-north-1"
@@ -40,38 +40,63 @@
 class TestRedshiftSQLOpenLineage:
     @patch.dict("os.environ", AIRFLOW_CONN_AWS_DEFAULT=f"aws://?region_name={MOCK_REGION_NAME}")
     @pytest.mark.parametrize(
-        "connection_host, connection_extra, expected_identity",
+        "connection_host, connection_extra, expected_identity, is_over_210, expected_schemaname",
         [
             # test without a connection host but with a cluster_identifier in connection extra
             (
                 None,
                 {"iam": True, "cluster_identifier": "cluster_identifier_from_extra"},
                 f"cluster_identifier_from_extra.{MOCK_REGION_NAME}",
+                True,
+                "database.public",
             ),
             # test with a connection host and without a cluster_identifier in connection extra
             (
                 "cluster_identifier_from_host.id.my_region.redshift.amazonaws.com",
                 {"iam": True},
                 "cluster_identifier_from_host.my_region",
+                True,
+                "database.public",
             ),
             # test with both connection host and cluster_identifier in connection extra
             (
                 "cluster_identifier_from_host.x.y",
                 {"iam": True, "cluster_identifier": "cluster_identifier_from_extra"},
                 f"cluster_identifier_from_extra.{MOCK_REGION_NAME}",
+                True,
+                "database.public",
             ),
             # test when hostname doesn't match pattern
+            ("1.2.3.4", {}, "1.2.3.4", True, "database.public"),
+            # test with Airflow below 2.10 not using Hook connection
             (
-                "1.2.3.4",
-                {},
-                "1.2.3.4",
+                "cluster_identifier_from_host.id.my_region.redshift.amazonaws.com",
+                {"iam": True},
+                "cluster_identifier_from_host.my_region",
+                False,
+                "public",
             ),
         ],
     )
+    @patch(
+        "airflow.providers.amazon.aws.hooks.redshift_sql._IS_AIRFLOW_2_10_OR_HIGHER",
+        new_callable=PropertyMock,
+    )
+    @patch("airflow.providers.openlineage.utils.utils._IS_AIRFLOW_2_10_OR_HIGHER", new_callable=PropertyMock)
     @patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
     def test_execute_openlineage_events(
-        self, mock_aws_hook_conn, connection_host, connection_extra, expected_identity
+        self,
+        mock_aws_hook_conn,
+        mock_ol_utils,
+        mock_redshift_sql,
+        connection_host,
+        connection_extra,
+        expected_identity,
+        is_over_210,
+        expected_schemaname,
     ):
+        mock_ol_utils.__bool__ = lambda x: is_over_210
+        mock_redshift_sql.__bool__ = lambda x: is_over_210
         DB_NAME = "database"
         DB_SCHEMA_NAME = "public"
 
@@ -84,14 +109,14 @@ def test_execute_openlineage_events(
             "DbUser": "IAM:user",
         }
 
-        class RedshiftSQLHookForTests(RedshiftSQLHook):
+        class RedshiftSQLHook(OriginalRedshiftSQLHook):
             get_conn = MagicMock(name="conn")
             get_connection = MagicMock()
 
             def get_first(self, *_):
                 return [f"{DB_NAME}.{DB_SCHEMA_NAME}"]
 
-        dbapi_hook = RedshiftSQLHookForTests()
+        dbapi_hook = RedshiftSQLHook()
 
         class RedshiftOperatorForTest(SQLExecuteQueryOperator):
             def get_db_hook(self):
@@ -149,94 +174,97 @@ def get_db_hook(self):
         dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = rows
 
         lineage = op.get_openlineage_facets_on_start()
-        assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == [
-            call(
-                "SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
-                "SVV_REDSHIFT_COLUMNS.table_name, "
-                "SVV_REDSHIFT_COLUMNS.column_name, "
-                "SVV_REDSHIFT_COLUMNS.ordinal_position, "
-                "SVV_REDSHIFT_COLUMNS.data_type, "
-                "SVV_REDSHIFT_COLUMNS.database_name \n"
-                "FROM SVV_REDSHIFT_COLUMNS \n"
-                "WHERE SVV_REDSHIFT_COLUMNS.schema_name = 'database.public' "
-                "AND SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') "
-                "OR SVV_REDSHIFT_COLUMNS.database_name = 'another_db' "
-                "AND SVV_REDSHIFT_COLUMNS.schema_name = 'another_schema' AND "
-                "SVV_REDSHIFT_COLUMNS.table_name IN ('popular_orders_day_of_week')"
-            ),
-            call(
-                "SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
-                "SVV_REDSHIFT_COLUMNS.table_name, "
-                "SVV_REDSHIFT_COLUMNS.column_name, "
-                "SVV_REDSHIFT_COLUMNS.ordinal_position, "
-                "SVV_REDSHIFT_COLUMNS.data_type, "
-                "SVV_REDSHIFT_COLUMNS.database_name \n"
-                "FROM SVV_REDSHIFT_COLUMNS \n"
-                "WHERE SVV_REDSHIFT_COLUMNS.schema_name = 'database.public' "
-                "AND SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')"
-            ),
-        ]
-
+        if is_over_210:
+            assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == [
+                call(
+                    "SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
+                    "SVV_REDSHIFT_COLUMNS.table_name, "
+                    "SVV_REDSHIFT_COLUMNS.column_name, "
+                    "SVV_REDSHIFT_COLUMNS.ordinal_position, "
+                    "SVV_REDSHIFT_COLUMNS.data_type, "
+                    "SVV_REDSHIFT_COLUMNS.database_name \n"
+                    "FROM SVV_REDSHIFT_COLUMNS \n"
+                    f"WHERE SVV_REDSHIFT_COLUMNS.schema_name = '{expected_schemaname}' "
+                    "AND SVV_REDSHIFT_COLUMNS.table_name IN ('little_table') "
+                    "OR SVV_REDSHIFT_COLUMNS.database_name = 'another_db' "
+                    "AND SVV_REDSHIFT_COLUMNS.schema_name = 'another_schema' AND "
+                    "SVV_REDSHIFT_COLUMNS.table_name IN ('popular_orders_day_of_week')"
+                ),
+                call(
+                    "SELECT SVV_REDSHIFT_COLUMNS.schema_name, "
+                    "SVV_REDSHIFT_COLUMNS.table_name, "
+                    "SVV_REDSHIFT_COLUMNS.column_name, "
+                    "SVV_REDSHIFT_COLUMNS.ordinal_position, "
+                    "SVV_REDSHIFT_COLUMNS.data_type, "
+                    "SVV_REDSHIFT_COLUMNS.database_name \n"
+                    "FROM SVV_REDSHIFT_COLUMNS \n"
+                    f"WHERE SVV_REDSHIFT_COLUMNS.schema_name = '{expected_schemaname}' "
+                    "AND SVV_REDSHIFT_COLUMNS.table_name IN ('Test_table')"
+                ),
+            ]
+        else:
+            assert dbapi_hook.get_conn.return_value.cursor.return_value.execute.mock_calls == []
         expected_namespace = f"redshift://{expected_identity}:5439"
 
-        assert lineage.inputs == [
-            Dataset(
-                namespace=expected_namespace,
-                name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.popular_orders_day_of_week",
-                facets={
-                    "schema": SchemaDatasetFacet(
-                        fields=[
-                            SchemaField(name="order_day_of_week", type="varchar"),
-                            SchemaField(name="order_placed_on", type="timestamp"),
-                            SchemaField(name="orders_placed", type="int4"),
-                        ]
-                    )
-                },
-            ),
-            Dataset(
-                namespace=expected_namespace,
-                name=f"{DB_NAME}.{DB_SCHEMA_NAME}.little_table",
-                facets={
-                    "schema": SchemaDatasetFacet(
-                        fields=[
-                            SchemaField(name="order_day_of_week", type="varchar"),
-                            SchemaField(name="additional_constant", type="varchar"),
-                        ]
-                    )
-                },
-            ),
-        ]
-        assert lineage.outputs == [
-            Dataset(
-                namespace=expected_namespace,
-                name=f"{DB_NAME}.{DB_SCHEMA_NAME}.test_table",
-                facets={
-                    "schema": SchemaDatasetFacet(
-                        fields=[
-                            SchemaField(name="order_day_of_week", type="varchar"),
-                            SchemaField(name="order_placed_on", type="timestamp"),
-                            SchemaField(name="orders_placed", type="int4"),
-                            SchemaField(name="additional_constant", type="varchar"),
-                        ]
-                    ),
-                    "columnLineage": ColumnLineageDatasetFacet(
-                        fields={
-                            "additional_constant": ColumnLineageDatasetFacetFieldsAdditional(
-                                inputFields=[
-                                    ColumnLineageDatasetFacetFieldsAdditionalInputFields(
-                                        namespace=expected_namespace,
-                                        name="database.public.little_table",
-                                        field="additional_constant",
-                                    )
-                                ],
-                                transformationDescription="",
-                                transformationType="",
-                            )
-                        }
-                    ),
-                },
-            )
-        ]
+        if is_over_210:
+            assert lineage.inputs == [
+                Dataset(
+                    namespace=expected_namespace,
+                    name=f"{ANOTHER_DB_NAME}.{ANOTHER_DB_SCHEMA}.popular_orders_day_of_week",
+                    facets={
+                        "schema": SchemaDatasetFacet(
+                            fields=[
+                                SchemaField(name="order_day_of_week", type="varchar"),
+                                SchemaField(name="order_placed_on", type="timestamp"),
+                                SchemaField(name="orders_placed", type="int4"),
+                            ]
+                        )
+                    },
+                ),
+                Dataset(
+                    namespace=expected_namespace,
+                    name=f"{DB_NAME}.{DB_SCHEMA_NAME}.little_table",
+                    facets={
+                        "schema": SchemaDatasetFacet(
+                            fields=[
+                                SchemaField(name="order_day_of_week", type="varchar"),
+                                SchemaField(name="additional_constant", type="varchar"),
+                            ]
+                        )
+                    },
+                ),
+            ]
+
+            assert lineage.outputs == [
+                Dataset(
+                    namespace=expected_namespace,
+                    name=f"{DB_NAME}.{DB_SCHEMA_NAME}.test_table",
+                    facets={
+                        "schema": SchemaDatasetFacet(
+                            fields=[
+                                SchemaField(name="order_day_of_week", type="varchar"),
+                                SchemaField(name="order_placed_on", type="timestamp"),
+                                SchemaField(name="orders_placed", type="int4"),
+                                SchemaField(name="additional_constant", type="varchar"),
+                            ]
+                        ),
+                        "columnLineage": ColumnLineageDatasetFacet(
+                            fields={
+                                "additional_constant": ColumnLineageDatasetFacetFieldsAdditional(
+                                    inputFields=[
+                                        ColumnLineageDatasetFacetFieldsAdditionalInputFields(
+                                            namespace=expected_namespace,
+                                            name="database.public.little_table",
+                                            field="additional_constant",
+                                        )
+                                    ],
+                                    transformationDescription="",
+                                    transformationType="",
+                                )
+                            }
+                        ),
+                    },
+                )
+            ]
 
         assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)}
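
Note on the version gate used in both files: the comparison goes through base_version so that
pre-release and dev builds of 2.10 also enable the new behavior. A minimal sketch of the same
pattern, with an illustrative version string (not what any particular build reports):

    from packaging.version import Version

    AIRFLOW_VERSION = "2.10.0.dev0"  # illustrative pre-release version string

    # A naive Version(AIRFLOW_VERSION) >= Version("2.10.0") check is False for
    # "2.10.0.dev0", because dev releases sort before the final release.
    # base_version strips the dev/rc/local segments ("2.10.0.dev0" -> "2.10.0"),
    # so the flag also flips on for pre-release builds of 2.10.
    _IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")

    assert _IS_AIRFLOW_2_10_OR_HIGHER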
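
The hook change is a gated override: on 2.10+ it is safe to issue a query from the lineage code
path, otherwise it defers to the base implementation. A simplified sketch with stand-in classes;
the "public" default is an assumption standing in for DbApiHook's behavior, and get_first here
stubs the real database round-trip:

    class BaseHook:
        def get_openlineage_default_schema(self):
            # assumed stand-in for the DbApiHook default
            return "public"

    class GatedRedshiftHook(BaseHook):
        is_airflow_2_10_or_higher = True  # stands in for the module-level flag

        def get_first(self, sql):
            return ["database.public"]  # stub for the real DB round-trip

        def get_openlineage_default_schema(self):
            if self.is_airflow_2_10_or_higher:
                # execution is process-isolated on 2.10+, so querying here is safe
                return self.get_first("SELECT CURRENT_SCHEMA();")[0]
            return super().get_openlineage_default_schema()

    assert GatedRedshiftHook().get_openlineage_default_schema() == "database.public"

This mirrors the parametrized expectations above: "database.public" (the mocked get_first result)
on 2.10 or higher, plain "public" below it.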
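
On the test side, the module-level flags are swapped out with patch(..., new_callable=PropertyMock)
and the parametrized truthiness is injected by assigning __bool__. That works because the flags are
only ever truth-tested, and because every Mock instance gets its own class, so assigning a magic
method affects just that one mock. A self-contained sketch of the trick (the module and flag names
here are made up for illustration):

    import types
    from unittest import mock

    # A throwaway module with a boolean feature flag, standing in for the real ones.
    flags = types.ModuleType("flags")
    flags.IS_NEW_ENOUGH = True

    def gated() -> str:
        # Mirrors the `if not _IS_AIRFLOW_2_10_OR_HIGHER:` style checks in the diff.
        return "new path" if flags.IS_NEW_ENOUGH else "old path"

    with mock.patch.object(flags, "IS_NEW_ENOUGH", new_callable=mock.PropertyMock) as flag:
        # Mock supports assigning magic methods; the lambda receives the mock as its
        # first argument, so bool(flags.IS_NEW_ENOUGH) now evaluates to False.
        flag.__bool__ = lambda self: False
        assert gated() == "old path"

    assert gated() == "new path"  # the patch is undone on exiting the context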