Add support for query parameters to BigQueryCheckOperator (#40556) (#40558)

* Add support for query parameters to BigQueryCheckOperator (#40556)

Remove unnecessary space

* Add a unit test for BigQueryCheckOperator query params; fix missing 'self' reference

* Fix lint (#40558)

---------

Co-authored-by: Alden S. Page <alden.page@doubleverify.com>
aldenstpage and aldenpagedv authored Jul 3, 2024
1 parent a8c4830 commit 7e80dc6
Showing 2 changed files with 41 additions and 1 deletion.
11 changes: 11 additions & 0 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -222,13 +222,20 @@ class BigQueryCheckOperator(
:param deferrable: Run operator in the deferrable mode.
:param poll_interval: (Deferrable mode only) polling period in seconds to
check for the status of job.
:param query_params: a list of dictionaries containing query parameter types and
values, passed to BigQuery. The structure of each dictionary should match
'queryParameters' in the Google BigQuery Jobs API:
https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs.
For example, [{ 'name': 'corpus', 'parameterType': { 'type': 'STRING' },
'parameterValue': { 'value': 'romeoandjuliet' } }]. (templated)
"""

template_fields: Sequence[str] = (
"sql",
"gcp_conn_id",
"impersonation_chain",
"labels",
"query_params",
)
template_ext: Sequence[str] = (".sql",)
ui_color = BigQueryUIColors.CHECK.value
@@ -246,6 +253,7 @@ def __init__(
encryption_configuration: dict | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poll_interval: float = 4.0,
query_params: list | None = None,
**kwargs,
) -> None:
super().__init__(sql=sql, **kwargs)
@@ -257,6 +265,7 @@ def __init__(
self.encryption_configuration = encryption_configuration
self.deferrable = deferrable
self.poll_interval = poll_interval
self.query_params = query_params

def _submit_job(
self,
@@ -265,6 +274,8 @@ def _submit_job(
) -> BigQueryJob:
"""Submit a new job and get the job id for polling the status using Trigger."""
configuration = {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}}
if self.query_params:
configuration["query"]["queryParameters"] = self.query_params

self.include_encryption_configuration(configuration, "query")

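With this change, _submit_job attaches the list under configuration["query"]["queryParameters"] before calling insert_job. A minimal usage sketch follows; the project, dataset, table, and task names are hypothetical, and the parameter dict follows the 'queryParameters' structure from the docstring above:

from airflow.providers.google.cloud.operators.bigquery import BigQueryCheckOperator

# Hypothetical task: project, dataset, table, and SQL are illustrative only.
check_corpus = BigQueryCheckOperator(
    task_id="check_corpus_rows",
    sql="SELECT COUNT(*) FROM `my_project.my_dataset.samples` WHERE corpus = @corpus",
    use_legacy_sql=False,  # named query parameters require standard SQL
    query_params=[
        {
            "name": "corpus",
            "parameterType": {"type": "STRING"},
            "parameterValue": {"value": "romeoandjuliet"},
        }
    ],
)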
31 changes: 30 additions & 1 deletion tests/providers/google/cloud/operators/test_bigquery.py
@@ -24,7 +24,7 @@

import pandas as pd
import pytest
from google.cloud.bigquery import DEFAULT_RETRY
from google.cloud.bigquery import DEFAULT_RETRY, ScalarQueryParameter
from google.cloud.exceptions import Conflict
from openlineage.client.facet import ErrorMessageRunFacet, ExternalQueryRunFacet, SqlJobFacet
from openlineage.client.run import Dataset
@@ -2293,6 +2293,35 @@ def test_bigquery_check_operator_async_finish_before_deferred(
mock_defer.assert_not_called()
mock_validate_records.assert_called_once_with((1, 2, 3))

@pytest.mark.db_test
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator._validate_records")
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_check_operator_query_parameters_passing(
self, mock_hook, mock_validate_records, create_task_instance_of_operator
):
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"
query_params = [ScalarQueryParameter("test_param", "INT64", 1)]

mocked_job = MagicMock(job_id=real_job_id, error_result=False)
mocked_job.result.return_value = iter([(1, 2, 3)]) # mock rows generator
mock_hook.return_value.insert_job.return_value = mocked_job
mock_hook.return_value.insert_job.return_value.running.return_value = False

ti = create_task_instance_of_operator(
BigQueryCheckOperator,
dag_id="dag_id",
task_id="bq_check_operator_query_params_job",
sql="SELECT * FROM any WHERE test_param = @test_param",
location=TEST_DATASET_LOCATION,
deferrable=True,
query_params=query_params,
)

ti.task.execute(MagicMock())
mock_validate_records.assert_called_once_with((1, 2, 3))

@pytest.mark.db_test
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_check_operator_async_finish_with_error_before_deferred(
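The test passes a google-cloud-bigquery ScalarQueryParameter rather than a raw dict; since the hook is mocked, serialization of that object is not exercised here. A minimal sketch, assuming the google-cloud-bigquery client library is installed, of how such an object maps to the REST-style dictionary the operator docstring describes:

from google.cloud.bigquery import ScalarQueryParameter

# Build the same parameter used in the test above.
param = ScalarQueryParameter("test_param", "INT64", 1)

# to_api_repr() yields the 'queryParameters' entry shape from the Jobs API, roughly:
# {'name': 'test_param', 'parameterType': {'type': 'INT64'}, 'parameterValue': {'value': '1'}}
print(param.to_api_repr())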
