Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add use_legacy_sql param to BigQueryGetDataOperator #31190

Merged
merged 1 commit into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions airflow/providers/google/cloud/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
Defaults to 4 seconds.
:param as_dict: if True returns the result as a list of dictionaries, otherwise as list of lists
(default: False).
:param use_legacy_sql: Whether to use legacy SQL (True) or standard SQL (False). Defaults to True.
"""

template_fields: Sequence[str] = (
Expand All @@ -845,6 +846,7 @@ def __init__(
deferrable: bool = False,
poll_interval: float = 4.0,
as_dict: bool = False,
use_legacy_sql: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -860,14 +862,15 @@ def __init__(
self.deferrable = deferrable
self.poll_interval = poll_interval
self.as_dict = as_dict
self.use_legacy_sql = use_legacy_sql

def _submit_job(
self,
hook: BigQueryHook,
job_id: str,
) -> BigQueryJob:
get_query = self.generate_query()
configuration = {"query": {"query": get_query}}
configuration = {"query": {"query": get_query, "useLegacySql": self.use_legacy_sql}}
"""Submit a new job and get the job id for polling the status using Triggerer."""
return hook.insert_job(
configuration=configuration,
Expand All @@ -887,18 +890,23 @@ def generate_query(self) -> str:
query += self.selected_fields
else:
query += "*"
query += f" from {self.dataset_id}.{self.table_id} limit {self.max_results}"
query += f" from `{self.project_id}.{self.dataset_id}.{self.table_id}` limit {self.max_results}"
return query

def execute(self, context: Context):
hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
use_legacy_sql=self.use_legacy_sql,
)

if not self.deferrable:
self.log.info(
"Fetching Data from %s.%s max results: %s", self.dataset_id, self.table_id, self.max_results
"Fetching Data from %s.%s.%s max results: %s",
self.project_id,
self.dataset_id,
self.table_id,
self.max_results,
)
if not self.selected_fields:
schema: dict[str, list] = hook.get_schema(
Expand Down
20 changes: 8 additions & 12 deletions tests/providers/google/cloud/operators/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
"refreshIntervalMs": 2000000,
}
TEST_TABLE = "test-table"
GCP_CONN_ID = "google_cloud_default"


class TestBigQueryCreateEmptyTableOperator:
Expand Down Expand Up @@ -791,6 +792,7 @@ def test_execute(self, mock_hook, as_dict):
max_results = 100
selected_fields = "DATE"
operator = BigQueryGetDataOperator(
gcp_conn_id=GCP_CONN_ID,
task_id=TASK_ID,
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
Expand All @@ -799,8 +801,10 @@ def test_execute(self, mock_hook, as_dict):
selected_fields=selected_fields,
location=TEST_DATASET_LOCATION,
as_dict=as_dict,
use_legacy_sql=False,
)
operator.execute(None)
mock_hook.assert_called_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, use_legacy_sql=False)
mock_hook.return_value.list_rows.assert_called_once_with(
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
Expand All @@ -818,12 +822,6 @@ def test_bigquery_get_data_operator_async_with_selected_fields(
Asserts that a task is deferred and a BigQueryGetDataTrigger will be fired
when the BigQueryGetDataOperator is executed with deferrable=True.
"""
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"

mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)

ti = create_task_instance_of_operator(
BigQueryGetDataOperator,
dag_id="dag_id",
Expand All @@ -833,6 +831,7 @@ def test_bigquery_get_data_operator_async_with_selected_fields(
max_results=100,
selected_fields="value,name",
deferrable=True,
use_legacy_sql=False,
)

with pytest.raises(TaskDeferred) as exc:
Expand All @@ -851,12 +850,6 @@ def test_bigquery_get_data_operator_async_without_selected_fields(
Asserts that a task is deferred and a BigQueryGetDataTrigger will be fired
when the BigQueryGetDataOperator is executed with deferrable=True.
"""
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"

mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)

ti = create_task_instance_of_operator(
BigQueryGetDataOperator,
dag_id="dag_id",
Expand All @@ -866,6 +859,7 @@ def test_bigquery_get_data_operator_async_without_selected_fields(
max_results=100,
deferrable=True,
as_dict=as_dict,
use_legacy_sql=False,
)

with pytest.raises(TaskDeferred) as exc:
Expand All @@ -886,6 +880,7 @@ def test_bigquery_get_data_operator_execute_failure(self, as_dict):
max_results=100,
deferrable=True,
as_dict=as_dict,
use_legacy_sql=False,
)

with pytest.raises(AirflowException):
Expand All @@ -904,6 +899,7 @@ def test_bigquery_get_data_op_execute_complete_with_records(self, as_dict):
max_results=100,
deferrable=True,
as_dict=as_dict,
use_legacy_sql=False,
)

with mock.patch.object(operator.log, "info") as mock_log_info:
Expand Down