From 1029aead9de2aa0240e832192b00842ee07d16ac Mon Sep 17 00:00:00 2001 From: Nathaniel Young Date: Fri, 24 May 2024 18:16:14 -0700 Subject: [PATCH 1/3] standardizes template fields for BaseSQLOperator --- airflow/providers/common/sql/operators/sql.py | 27 ++++++++++--------- .../common/sql/operators/test_sql.py | 23 ++++++++++++++++ 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py index ea791992d5c95..f4598efb093f9 100644 --- a/airflow/providers/common/sql/operators/sql.py +++ b/airflow/providers/common/sql/operators/sql.py @@ -127,6 +127,8 @@ class BaseSQLOperator(BaseOperator): conn_id_field = "conn_id" + template_fields: Sequence[str] = ("conn_id", "database", "hook_params") + def __init__( self, *, @@ -220,7 +222,7 @@ class SQLExecuteQueryOperator(BaseSQLOperator): :ref:`howto/operator:SQLExecuteQueryOperator` """ - template_fields: Sequence[str] = ("conn_id", "sql", "parameters", "hook_params") + template_fields: Sequence[str] = ("sql", "parameters", *BaseSQLOperator.template_fields) template_ext: Sequence[str] = (".sql", ".json") template_fields_renderers = {"sql": "sql", "parameters": "json"} ui_color = "#cdaaed" @@ -425,7 +427,7 @@ class SQLColumnCheckOperator(BaseSQLOperator): :ref:`howto/operator:SQLColumnCheckOperator` """ - template_fields: Sequence[str] = ("partition_clause", "table", "sql", "hook_params") + template_fields: Sequence[str] = ("table", "partition_clause", "sql", *BaseSQLOperator.template_fields) template_fields_renderers = {"sql": "sql"} sql_check_template = """ @@ -653,7 +655,7 @@ class SQLTableCheckOperator(BaseSQLOperator): :ref:`howto/operator:SQLTableCheckOperator` """ - template_fields: Sequence[str] = ("partition_clause", "table", "sql", "conn_id", "hook_params") + template_fields: Sequence[str] = ("table", "partition_clause", "sql", *BaseSQLOperator.template_fields) template_fields_renderers = {"sql": "sql"} @@ -769,7 +771,7 @@ class SQLCheckOperator(BaseSQLOperator): :param parameters: (optional) the parameters to render the SQL query with. """ - template_fields: Sequence[str] = ("sql", "hook_params") + template_fields: Sequence[str] = ("sql", *BaseSQLOperator.template_fields) template_ext: Sequence[str] = ( ".hql", ".sql", @@ -815,11 +817,7 @@ class SQLValueCheckOperator(BaseSQLOperator): """ __mapper_args__ = {"polymorphic_identity": "SQLValueCheckOperator"} - template_fields: Sequence[str] = ( - "sql", - "pass_value", - "hook_params", - ) + template_fields: Sequence[str] = ("sql", "pass_value", *BaseSQLOperator.template_fields) template_ext: Sequence[str] = ( ".hql", ".sql", @@ -916,7 +914,7 @@ class SQLIntervalCheckOperator(BaseSQLOperator): """ __mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"} - template_fields: Sequence[str] = ("sql1", "sql2", "hook_params") + template_fields: Sequence[str] = ("sql1", "sql2", *BaseSQLOperator.template_fields) template_ext: Sequence[str] = ( ".hql", ".sql", @@ -1044,7 +1042,12 @@ class SQLThresholdCheckOperator(BaseSQLOperator): :param max_threshold: numerical value or max threshold sql to be executed (templated) """ - template_fields: Sequence[str] = ("sql", "min_threshold", "max_threshold", "hook_params") + template_fields: Sequence[str] = ( + "sql", + "min_threshold", + "max_threshold", + *BaseSQLOperator.template_fields, + ) template_ext: Sequence[str] = ( ".hql", ".sql", @@ -1142,7 +1145,7 @@ class BranchSQLOperator(BaseSQLOperator, SkipMixin): :param parameters: (optional) the parameters to render the SQL query with. """ - template_fields: Sequence[str] = ("sql",) + template_fields: Sequence[str] = ("sql", *BaseSQLOperator.template_fields) template_ext: Sequence[str] = (".sql",) template_fields_renderers = {"sql": "sql"} ui_color = "#a22034" diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py index 85d26c75aede2..105099fda0531 100644 --- a/tests/providers/common/sql/operators/test_sql.py +++ b/tests/providers/common/sql/operators/test_sql.py @@ -29,6 +29,7 @@ from airflow.operators.empty import EmptyOperator from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.common.sql.operators.sql import ( + BaseSQLOperator, BranchSQLOperator, SQLCheckOperator, SQLColumnCheckOperator, @@ -59,6 +60,28 @@ def _get_mock_db_hook(): return MockHook() +class TestBaseSQLOperator: + def _construct_operator(self, **kwargs): + dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1), render_template_as_native_obj=True) + return BaseSQLOperator( + task_id="test_task", + conn_id="{{ conn_id }}", + database="{{ database }}", + hook_params="{{ hook_params }}", + **kwargs, + dag=dag, + ) + + def test_templated_fields(self): + operator = self._construct_operator() + operator.render_template_fields( + {"conn_id": "my_conn_id", "database": "my_database", "hook_params": {"key": "value"}} + ) + assert operator.conn_id == "my_conn_id" + assert operator.database == "my_database" + assert operator.hook_params == {"key": "value"} + + class TestSQLExecuteQueryOperator: def _construct_operator(self, sql, **kwargs): dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1)) From e3a0f4a425b7fe6b07e0d531846785255c286e33 Mon Sep 17 00:00:00 2001 From: Nathaniel Young Date: Fri, 24 May 2024 20:54:41 -0700 Subject: [PATCH 2/3] adds template_fields sequence string type --- airflow/providers/common/sql/operators/sql.pyi | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/providers/common/sql/operators/sql.pyi b/airflow/providers/common/sql/operators/sql.pyi index fc0789928077a..f8fa23c37edc7 100644 --- a/airflow/providers/common/sql/operators/sql.pyi +++ b/airflow/providers/common/sql/operators/sql.pyi @@ -54,6 +54,7 @@ def parse_boolean(val: str) -> str | bool: ... class BaseSQLOperator(BaseOperator): conn_id_field: str + template_fields: Sequence[str] conn_id: Incomplete database: Incomplete hook_params: Incomplete From 23d782dfcc7e472a793920b5c7e75ba95b1608be Mon Sep 17 00:00:00 2001 From: Nathaniel Young Date: Fri, 24 May 2024 21:20:31 -0700 Subject: [PATCH 3/3] fixes hook params check in init --- airflow/providers/common/sql/operators/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py index f4598efb093f9..d50a6bf0f5926 100644 --- a/airflow/providers/common/sql/operators/sql.py +++ b/airflow/providers/common/sql/operators/sql.py @@ -141,7 +141,7 @@ def __init__( super().__init__(**kwargs) self.conn_id = conn_id self.database = database - self.hook_params = {} if hook_params is None else hook_params + self.hook_params = hook_params or {} self.retry_on_failure = retry_on_failure @cached_property