diff --git a/airflow/providers/mysql/operators/mysql.py b/airflow/providers/mysql/operators/mysql.py index 9fa249ee3e80c..18f3d0ccdc444 100644 --- a/airflow/providers/mysql/operators/mysql.py +++ b/airflow/providers/mysql/operators/mysql.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import ast from typing import Dict, Iterable, Mapping, Optional, Union from airflow.models import BaseOperator @@ -37,6 +38,8 @@ class MySqlOperator(BaseOperator): :param mysql_conn_id: Reference to :ref:`mysql connection id `. :type mysql_conn_id: str :param parameters: (optional) the parameters to render the SQL query with. + Template reference are recognized by str ending in '.json' + (templated) :type parameters: dict or iterable :param autocommit: if True, each command is automatically committed. (default value: False) @@ -67,6 +70,11 @@ def __init__( self.parameters = parameters self.database = database + def prepare_template(self) -> None: + """Parse template file for attribute parameters.""" + if isinstance(self.parameters, str): + self.parameters = ast.literal_eval(self.parameters) + def execute(self, context: Dict) -> None: self.log.info('Executing: %s', self.sql) hook = MySqlHook(mysql_conn_id=self.mysql_conn_id, schema=self.database) diff --git a/tests/providers/mysql/operators/test_mysql.py b/tests/providers/mysql/operators/test_mysql.py index f36fab88b14e4..c31ab73ae9757 100644 --- a/tests/providers/mysql/operators/test_mysql.py +++ b/tests/providers/mysql/operators/test_mysql.py @@ -15,8 +15,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import os import unittest from contextlib import closing +from tempfile import NamedTemporaryFile import pytest from parameterized import parameterized @@ -108,3 +110,19 @@ def test_overwrite_schema(self, client): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) except OperationalError as e: assert "Unknown database 'foobar'" in str(e) + + def test_mysql_operator_resolve_parameters_template_json_file(self): + + with NamedTemporaryFile(suffix='.json') as f: + f.write(b"{\n \"foo\": \"{{ ds }}\"}") + f.flush() + template_dir = os.path.dirname(f.name) + template_file = os.path.basename(f.name) + + with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=template_dir): + task = MySqlOperator(task_id="op1", parameters=template_file, sql="SELECT 1") + + task.resolve_template_files() + + assert isinstance(task.parameters, dict) + assert task.parameters["foo"] == "{{ ds }}"