diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index d3f79c9bbde4d..43131c549a3c3 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -222,6 +222,12 @@ class BigQueryCheckOperator( :param deferrable: Run operator in the deferrable mode. :param poll_interval: (Deferrable mode only) polling period in seconds to check for the status of job. + :param query_params: a list of dictionary containing query parameter types and + values, passed to BigQuery. The structure of dictionary should look like + 'queryParameters' in Google BigQuery Jobs API: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs. + For example, [{ 'name': 'corpus', 'parameterType': { 'type': 'STRING' }, + 'parameterValue': { 'value': 'romeoandjuliet' } }]. (templated) """ template_fields: Sequence[str] = ( @@ -229,6 +235,7 @@ class BigQueryCheckOperator( "gcp_conn_id", "impersonation_chain", "labels", + "query_params", ) template_ext: Sequence[str] = (".sql",) ui_color = BigQueryUIColors.CHECK.value @@ -246,6 +253,7 @@ def __init__( encryption_configuration: dict | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poll_interval: float = 4.0, + query_params: list | None = None, **kwargs, ) -> None: super().__init__(sql=sql, **kwargs) @@ -257,6 +265,7 @@ def __init__( self.encryption_configuration = encryption_configuration self.deferrable = deferrable self.poll_interval = poll_interval + self.query_params = query_params def _submit_job( self, @@ -265,6 +274,8 @@ def _submit_job( ) -> BigQueryJob: """Submit a new job and get the job id for polling the status using Trigger.""" configuration = {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}} + if self.query_params: + configuration["query"]["queryParameters"] = self.query_params self.include_encryption_configuration(configuration, "query") diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index febcfe48712e7..d49e75c95070c 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -24,7 +24,7 @@ import pandas as pd import pytest -from google.cloud.bigquery import DEFAULT_RETRY +from google.cloud.bigquery import DEFAULT_RETRY, ScalarQueryParameter from google.cloud.exceptions import Conflict from openlineage.client.facet import ErrorMessageRunFacet, ExternalQueryRunFacet, SqlJobFacet from openlineage.client.run import Dataset @@ -2293,6 +2293,35 @@ def test_bigquery_check_operator_async_finish_before_deferred( mock_defer.assert_not_called() mock_validate_records.assert_called_once_with((1, 2, 3)) + @pytest.mark.db_test + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator._validate_records") + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + def test_bigquery_check_operator_query_parameters_passing( + self, mock_hook, mock_validate_records, create_task_instance_of_operator + ): + job_id = "123456" + hash_ = "hash" + real_job_id = f"{job_id}_{hash_}" + query_params = [ScalarQueryParameter("test_param", "INT64", 1)] + + mocked_job = MagicMock(job_id=real_job_id, error_result=False) + mocked_job.result.return_value = iter([(1, 2, 3)]) # mock rows generator + mock_hook.return_value.insert_job.return_value = mocked_job + mock_hook.return_value.insert_job.return_value.running.return_value = False + + ti = create_task_instance_of_operator( + BigQueryCheckOperator, + dag_id="dag_id", + task_id="bq_check_operator_query_params_job", + sql="SELECT * FROM any WHERE test_param = @test_param", + location=TEST_DATASET_LOCATION, + deferrable=True, + query_params=query_params, + ) + + ti.task.execute(MagicMock()) + mock_validate_records.assert_called_once_with((1, 2, 3)) + @pytest.mark.db_test @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_bigquery_check_operator_async_finish_with_error_before_deferred(