Add support for query parameters to BigQueryCheckOperator (#40556) #40558

Merged
merged 4 commits into from Jul 3, 2024
Changes from 2 commits
11 changes: 11 additions & 0 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -222,13 +222,20 @@ class BigQueryCheckOperator(
:param deferrable: Run operator in the deferrable mode.
:param poll_interval: (Deferrable mode only) polling period in seconds to
check for the status of job.
:param query_params: a list of dictionaries containing query parameter types and
values, passed to BigQuery. The structure of each dictionary should match
'queryParameters' in the Google BigQuery Jobs API:
https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs.
For example, [{ 'name': 'corpus', 'parameterType': { 'type': 'STRING' },
'parameterValue': { 'value': 'romeoandjuliet' } }]. (templated)
"""

template_fields: Sequence[str] = (
"sql",
"gcp_conn_id",
"impersonation_chain",
"labels",
"query_params",
)
template_ext: Sequence[str] = (".sql",)
ui_color = BigQueryUIColors.CHECK.value
@@ -246,6 +253,7 @@ def __init__(
encryption_configuration: dict | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poll_interval: float = 4.0,
query_params: list | None = None,
**kwargs,
) -> None:
super().__init__(sql=sql, **kwargs)
@@ -257,6 +265,7 @@ def __init__(
self.encryption_configuration = encryption_configuration
self.deferrable = deferrable
self.poll_interval = poll_interval
self.query_params = query_params

def _submit_job(
self,
@@ -265,6 +274,8 @@ def _submit_job(
) -> BigQueryJob:
"""Submit a new job and get the job id for polling the status using Trigger."""
configuration = {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}}
if self.query_params:
configuration["query"]["queryParameters"] = self.query_params

self.include_encryption_configuration(configuration, "query")

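For context, here is a minimal sketch of how the new `query_params` argument could be used with `BigQueryCheckOperator`. The DAG id, task id, schedule, and queried table are illustrative placeholders only; the parameter follows the `queryParameters` shape from the BigQuery Jobs API described in the docstring above.

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.bigquery import BigQueryCheckOperator

# Illustrative DAG; dag_id, task_id, and the queried table are placeholders.
with DAG(dag_id="bq_check_with_query_params", start_date=datetime(2024, 1, 1), schedule=None):
    check_corpus = BigQueryCheckOperator(
        task_id="check_corpus_rows",
        sql=(
            "SELECT COUNT(*) FROM `bigquery-public-data.samples.shakespeare` "
            "WHERE corpus = @corpus"
        ),
        use_legacy_sql=False,  # named query parameters require standard SQL
        # Same shape as 'queryParameters' in the BigQuery Jobs API (templated).
        query_params=[
            {
                "name": "corpus",
                "parameterType": {"type": "STRING"},
                "parameterValue": {"value": "romeoandjuliet"},
            }
        ],
    )
```

The check fails if the first result row contains a falsy value, so a zero count would fail the task, which is the typical use of this operator.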
31 changes: 30 additions & 1 deletion tests/providers/google/cloud/operators/test_bigquery.py
@@ -24,7 +24,7 @@

import pandas as pd
import pytest
from google.cloud.bigquery import DEFAULT_RETRY
from google.cloud.bigquery import DEFAULT_RETRY, ScalarQueryParameter
from google.cloud.exceptions import Conflict
from openlineage.client.facet import ErrorMessageRunFacet, ExternalQueryRunFacet, SqlJobFacet
from openlineage.client.run import Dataset
@@ -2293,6 +2293,35 @@ def test_bigquery_check_operator_async_finish_before_deferred(
mock_defer.assert_not_called()
mock_validate_records.assert_called_once_with((1, 2, 3))

@pytest.mark.db_test
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator._validate_records")
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_check_operator_query_parameters_passing(
self, mock_hook, mock_validate_records, create_task_instance_of_operator
):
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"
query_params = [ScalarQueryParameter("test_param", "INT64", 1)]

mocked_job = MagicMock(job_id=real_job_id, error_result=False)
mocked_job.result.return_value = iter([(1, 2, 3)]) # mock rows generator
mock_hook.return_value.insert_job.return_value = mocked_job
mock_hook.return_value.insert_job.return_value.running.return_value = False

ti = create_task_instance_of_operator(
BigQueryCheckOperator,
dag_id="dag_id",
task_id="bq_check_operator_query_params_job",
sql="SELECT * FROM any WHERE test_param = @test_param",
location=TEST_DATASET_LOCATION,
deferrable=True,
query_params=query_params,
)

ti.task.execute(MagicMock())
mock_validate_records.assert_called_once_with((1, 2, 3))

@pytest.mark.db_test
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_check_operator_async_finish_with_error_before_deferred(
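As a side note, the test above constructs its parameters with `ScalarQueryParameter` objects, while the operator docstring documents the raw `queryParameters` dictionary form; the client library can translate between the two. A quick illustration, assuming `google-cloud-bigquery` is installed:

```python
from google.cloud.bigquery import ScalarQueryParameter

# The parameter used in the test above, rendered in the dictionary shape that
# the query_params docstring documents ('queryParameters' in the Jobs API).
param = ScalarQueryParameter("test_param", "INT64", 1)
print(param.to_api_repr())
# {'parameterType': {'type': 'INT64'}, 'parameterValue': {'value': '1'}, 'name': 'test_param'}
```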