
Fix: re-enable use of parameters in gcs_to_bq which had been disabled (
mdering authored Dec 4, 2022
1 parent 5cdff50 commit 2d663df
Showing 2 changed files with 253 additions and 1 deletion.
26 changes: 25 additions & 1 deletion airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -27,7 +27,11 @@

 from airflow import AirflowException
 from airflow.models import BaseOperator
-from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
+from airflow.providers.google.cloud.hooks.bigquery import (
+    BigQueryHook,
+    BigQueryJob,
+    _cleanse_time_partitioning,
+)
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
 from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
 from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger
@@ -390,8 +394,28 @@ def execute(self, context: Context):
                 "ignoreUnknownValues": self.ignore_unknown_values,
                 "allowQuotedNewlines": self.allow_quoted_newlines,
                 "encoding": self.encoding,
+                "allowJaggedRows": self.allow_jagged_rows,
+                "fieldDelimiter": self.field_delimiter,
+                "maxBadRecords": self.max_bad_records,
+                "quote": self.quote_character,
+                "schemaUpdateOptions": self.schema_update_options,
             },
         }
+        if self.cluster_fields:
+            self.configuration["load"].update({"clustering": {"fields": self.cluster_fields}})
+        time_partitioning = _cleanse_time_partitioning(
+            self.destination_project_dataset_table, self.time_partitioning
+        )
+        if time_partitioning:
+            self.configuration["load"].update({"timePartitioning": time_partitioning})
+        # fields that should only be set if defined
+        set_if_def = {
+            "quote": self.quote_character,
+            "destinationEncryptionConfiguration": self.encryption_configuration,
+        }
+        for k, v in set_if_def.items():
+            if v:
+                self.configuration["load"][k] = v
         self.configuration = self._check_schema_fields(self.configuration)
         try:
             self.log.info("Executing: %s", self.configuration)
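For reference, a minimal sketch of how the re-enabled parameters land in the "load" section of the job configuration, including the set-only-if-defined handling of quote and destinationEncryptionConfiguration. build_load_config is a hypothetical name used only for illustration, not part of the provider:

# Illustrative sketch only -- mirrors how execute() above assembles
# the load configuration; not code from the commit.
from typing import Optional

def build_load_config(
    allow_jagged_rows: bool = False,
    field_delimiter: str = ",",
    max_bad_records: int = 0,
    quote_character: Optional[str] = None,
    schema_update_options: tuple = (),
    cluster_fields: Optional[list] = None,
    encryption_configuration: Optional[dict] = None,
    time_partitioning: Optional[dict] = None,
) -> dict:
    load: dict = {
        "allowJaggedRows": allow_jagged_rows,
        "fieldDelimiter": field_delimiter,
        "maxBadRecords": max_bad_records,
        "quote": quote_character,
        "schemaUpdateOptions": schema_update_options,
    }
    if cluster_fields:
        load["clustering"] = {"fields": cluster_fields}
    if time_partitioning:
        load["timePartitioning"] = time_partitioning
    # Fields that should only be set if defined, as in the set_if_def loop;
    # note "quote" is written unconditionally above and again here when truthy.
    set_if_def = {
        "quote": quote_character,
        "destinationEncryptionConfiguration": encryption_configuration,
    }
    for k, v in set_if_def.items():
        if v:
            load[k] = v
    return load

print(build_load_config(quote_character="|", cluster_fields=["field_1"]))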
228 changes: 228 additions & 0 deletions tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
@@ -163,6 +163,11 @@ def test_max_value_without_external_table_should_execute_successfully(self, hook):
                     allowQuotedNewlines=False,
                     encoding="UTF-8",
                     schema={"fields": SCHEMA_FIELDS},
+                    allowJaggedRows=False,
+                    fieldDelimiter=",",
+                    maxBadRecords=0,
+                    quote=None,
+                    schemaUpdateOptions=(),
                 ),
             },
             project_id=hook.return_value.project_id,
@@ -226,6 +231,11 @@ def test_max_value_should_throw_ex_when_query_returns_no_rows(self, hook):
                     allowQuotedNewlines=False,
                     encoding="UTF-8",
                     schema={"fields": SCHEMA_FIELDS},
+                    allowJaggedRows=False,
+                    fieldDelimiter=",",
+                    maxBadRecords=0,
+                    quote=None,
+                    schemaUpdateOptions=(),
                 ),
             },
             project_id=hook.return_value.project_id,
@@ -335,6 +345,11 @@ def test_labels_without_external_table_should_execute_successfully(self, hook):
                     allowQuotedNewlines=False,
                     encoding="UTF-8",
                     schema={"fields": SCHEMA_FIELDS},
+                    allowJaggedRows=False,
+                    fieldDelimiter=",",
+                    maxBadRecords=0,
+                    quote=None,
+                    schemaUpdateOptions=(),
                 ),
             },
             project_id=hook.return_value.project_id,
@@ -434,6 +449,11 @@ def test_description_without_external_table_should_execute_successfully(self, hook):
                     allowQuotedNewlines=False,
                     encoding="UTF-8",
                     schema={"fields": SCHEMA_FIELDS},
+                    allowJaggedRows=False,
+                    fieldDelimiter=",",
+                    maxBadRecords=0,
+                    quote=None,
+                    schemaUpdateOptions=(),
                 ),
             },
             project_id=hook.return_value.project_id,
@@ -535,6 +555,11 @@ def test_source_objs_as_list_without_external_table_should_execute_successfully(self, hook):
                     allowQuotedNewlines=False,
                     encoding="UTF-8",
                     schema={"fields": SCHEMA_FIELDS},
+                    allowJaggedRows=False,
+                    fieldDelimiter=",",
+                    maxBadRecords=0,
+                    quote=None,
+                    schemaUpdateOptions=(),
                 ),
             },
             project_id=hook.return_value.project_id,
@@ -632,6 +657,194 @@ def test_source_objs_as_string_without_external_table_should_execute_successfully(self, hook):
                     allowQuotedNewlines=False,
                     encoding="UTF-8",
                     schema={"fields": SCHEMA_FIELDS},
+                    allowJaggedRows=False,
+                    fieldDelimiter=",",
+                    maxBadRecords=0,
+                    quote=None,
+                    schemaUpdateOptions=(),
+                ),
+            },
+            project_id=hook.return_value.project_id,
+            location=None,
+            job_id=pytest.real_job_id,
+            timeout=None,
+            retry=DEFAULT_RETRY,
+            nowait=True,
+        ),
+    ]
+
+    hook.return_value.insert_job.assert_has_calls(calls)

+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+    def test_all_fields_should_be_present(self, hook):
+        hook.return_value.insert_job.side_effect = [
+            MagicMock(job_id=pytest.real_job_id, error_result=False),
+            pytest.real_job_id,
+        ]
+        hook.return_value.generate_job_id.return_value = pytest.real_job_id
+        hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+        operator = GCSToBigQueryOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            source_objects=TEST_SOURCE_OBJECTS,
+            schema_fields=SCHEMA_FIELDS,
+            destination_project_dataset_table=TEST_EXPLICIT_DEST,
+            write_disposition=WRITE_DISPOSITION,
+            external_table=False,
+            field_delimiter=";",
+            max_bad_records=13,
+            quote_character="|",
+            schema_update_options={"foo": "bar"},
+            allow_jagged_rows=True,
+            encryption_configuration={"bar": "baz"},
+            cluster_fields=["field_1", "field_2"],
+        )
+
+        operator.execute(context=MagicMock())
+
+        calls = [
+            call(
+                configuration={
+                    "load": dict(
+                        autodetect=True,
+                        createDisposition="CREATE_IF_NEEDED",
+                        destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+                        destinationTableProperties={
+                            "description": None,
+                            "labels": None,
+                        },
+                        sourceFormat="CSV",
+                        skipLeadingRows=None,
+                        sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+                        writeDisposition=WRITE_DISPOSITION,
+                        ignoreUnknownValues=False,
+                        allowQuotedNewlines=False,
+                        encoding="UTF-8",
+                        schema={"fields": SCHEMA_FIELDS},
+                        allowJaggedRows=True,
+                        fieldDelimiter=";",
+                        maxBadRecords=13,
+                        quote="|",
+                        schemaUpdateOptions={"foo": "bar"},
+                        destinationEncryptionConfiguration={"bar": "baz"},
+                        clustering={"fields": ["field_1", "field_2"]},
+                    ),
+                },
+                project_id=hook.return_value.project_id,
+                location=None,
+                job_id=pytest.real_job_id,
+                timeout=None,
+                retry=DEFAULT_RETRY,
+                nowait=True,
+            ),
+        ]
+
+        hook.return_value.insert_job.assert_has_calls(calls)
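The test above pins every re-enabled parameter at once. A hypothetical DAG usage of the same surface might look like the following; the DAG id, bucket, and table names are invented for illustration:

# Hypothetical usage sketch -- identifiers are invented, not from the commit.
import pendulum
from airflow import DAG
from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator

with DAG(
    dag_id="example_gcs_to_bq_parameters",
    start_date=pendulum.datetime(2022, 12, 1, tz="UTC"),
    schedule=None,
) as dag:
    load_csv = GCSToBigQueryOperator(
        task_id="gcs_to_bq",
        bucket="my-bucket",
        source_objects=["data/part-*.csv"],
        destination_project_dataset_table="my-project.my_dataset.my_table",
        external_table=False,
        autodetect=True,
        field_delimiter=";",
        max_bad_records=13,
        quote_character="|",
        allow_jagged_rows=True,
        cluster_fields=["field_1", "field_2"],
        time_partitioning={"type": "DAY"},
    )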

+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+    def test_date_partitioned_explicit_setting_should_be_found(self, hook):
+        hook.return_value.insert_job.side_effect = [
+            MagicMock(job_id=pytest.real_job_id, error_result=False),
+            pytest.real_job_id,
+        ]
+        hook.return_value.generate_job_id.return_value = pytest.real_job_id
+        hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+        operator = GCSToBigQueryOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            source_objects=TEST_SOURCE_OBJECTS,
+            schema_fields=SCHEMA_FIELDS,
+            destination_project_dataset_table=TEST_EXPLICIT_DEST,
+            write_disposition=WRITE_DISPOSITION,
+            external_table=False,
+            time_partitioning={"type": "DAY"},
+        )
+
+        operator.execute(context=MagicMock())
+
+        calls = [
+            call(
+                configuration={
+                    "load": dict(
+                        autodetect=True,
+                        createDisposition="CREATE_IF_NEEDED",
+                        destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+                        destinationTableProperties={
+                            "description": None,
+                            "labels": None,
+                        },
+                        sourceFormat="CSV",
+                        skipLeadingRows=None,
+                        sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+                        writeDisposition=WRITE_DISPOSITION,
+                        ignoreUnknownValues=False,
+                        allowQuotedNewlines=False,
+                        encoding="UTF-8",
+                        schema={"fields": SCHEMA_FIELDS},
+                        allowJaggedRows=False,
+                        fieldDelimiter=",",
+                        maxBadRecords=0,
+                        quote=None,
+                        schemaUpdateOptions=(),
+                        timePartitioning={"type": "DAY"},
+                    ),
+                },
+                project_id=hook.return_value.project_id,
+                location=None,
+                job_id=pytest.real_job_id,
+                timeout=None,
+                retry=DEFAULT_RETRY,
+                nowait=True,
+            ),
+        ]
+
+        hook.return_value.insert_job.assert_has_calls(calls)
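This test and the decorator-based one below both expect timePartitioning={"type": "DAY"}: one sets it explicitly, the other implies it through a "$20221123" suffix on the destination table. A minimal sketch of the behavior the two tests imply, inferred from them rather than copied from the hook's _cleanse_time_partitioning:

# Inferred behavior sketch -- not the hook's actual implementation.
from typing import Optional

def cleanse_time_partitioning(
    destination_dataset_table: Optional[str],
    time_partitioning_in: Optional[dict],
) -> dict:
    time_partitioning_out = {}
    # A "$YYYYMMDD" decorator on the table name implies daily partitioning.
    if destination_dataset_table and "$" in destination_dataset_table:
        time_partitioning_out["type"] = "DAY"
    # An explicit time_partitioning dict is merged on top.
    time_partitioning_out.update(time_partitioning_in or {})
    return time_partitioning_out

print(cleanse_time_partitioning("proj.ds.tbl", {"type": "DAY"}))  # explicit setting
print(cleanse_time_partitioning("proj.ds.tbl$20221123", None))    # implied by decorator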

+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+    def test_date_partitioned_implied_in_table_name_should_be_found(self, hook):
+        hook.return_value.insert_job.side_effect = [
+            MagicMock(job_id=pytest.real_job_id, error_result=False),
+            pytest.real_job_id,
+        ]
+        hook.return_value.generate_job_id.return_value = pytest.real_job_id
+        hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+        operator = GCSToBigQueryOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            source_objects=TEST_SOURCE_OBJECTS,
+            schema_fields=SCHEMA_FIELDS,
+            destination_project_dataset_table=TEST_EXPLICIT_DEST + "$20221123",
+            write_disposition=WRITE_DISPOSITION,
+            external_table=False,
+        )
+
+        operator.execute(context=MagicMock())
+
+        calls = [
+            call(
+                configuration={
+                    "load": dict(
+                        autodetect=True,
+                        createDisposition="CREATE_IF_NEEDED",
+                        destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+                        destinationTableProperties={
+                            "description": None,
+                            "labels": None,
+                        },
+                        sourceFormat="CSV",
+                        skipLeadingRows=None,
+                        sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+                        writeDisposition=WRITE_DISPOSITION,
+                        ignoreUnknownValues=False,
+                        allowQuotedNewlines=False,
+                        encoding="UTF-8",
+                        schema={"fields": SCHEMA_FIELDS},
+                        allowJaggedRows=False,
+                        fieldDelimiter=",",
+                        maxBadRecords=0,
+                        quote=None,
+                        schemaUpdateOptions=(),
+                        timePartitioning={"type": "DAY"},
+                    ),
+                },
+                project_id=hook.return_value.project_id,
@@ -830,6 +1043,11 @@ def test_schema_fields_scanner_without_external_table_should_execute_successfully
                     allowQuotedNewlines=False,
                     encoding="UTF-8",
                     schema={"fields": SCHEMA_FIELDS},
+                    allowJaggedRows=False,
+                    fieldDelimiter=",",
+                    maxBadRecords=0,
+                    quote=None,
+                    schemaUpdateOptions=(),
                 ),
             },
             project_id=bq_hook.return_value.project_id,
@@ -1023,6 +1241,11 @@ def test_schema_fields_integer_scanner_without_external_table_should_execute_successfully
                     ignoreUnknownValues=False,
                     allowQuotedNewlines=False,
                     encoding="UTF-8",
+                    allowJaggedRows=False,
+                    fieldDelimiter=",",
+                    maxBadRecords=0,
+                    quote=None,
+                    schemaUpdateOptions=(),
                 ),
             },
             project_id=bq_hook.return_value.project_id,
@@ -1087,6 +1310,11 @@ def test_schema_fields_without_external_table_should_execute_successfully(self, hook):
                     allowQuotedNewlines=False,
                     encoding="UTF-8",
                     schema={"fields": SCHEMA_FIELDS_INT},
+                    allowJaggedRows=False,
+                    fieldDelimiter=",",
+                    maxBadRecords=0,
+                    quote=None,
+                    schemaUpdateOptions=(),
                 ),
             },
             project_id=hook.return_value.project_id,
