From 026887ed84a3cbe78387c9261f7b4e229f322a97 Mon Sep 17 00:00:00 2001
From: Kacper Muda
Date: Tue, 10 Dec 2024 11:35:13 +0100
Subject: [PATCH] feat: add OpenLineage support for BigQuery Create Table
 operators (#44783)

Signed-off-by: Kacper Muda
---
 .../google/cloud/openlineage/utils.py       |  17 +++
 .../google/cloud/operators/bigquery.py      |  65 +++++++--
 .../google/cloud/openlineage/test_utils.py  |  12 ++
 .../google/cloud/operators/test_bigquery.py | 124 +++++++++++++++++-
 4 files changed, 204 insertions(+), 14 deletions(-)

diff --git a/providers/src/airflow/providers/google/cloud/openlineage/utils.py b/providers/src/airflow/providers/google/cloud/openlineage/utils.py
index 6b8c93063fc39..a8989b4eb8f2c 100644
--- a/providers/src/airflow/providers/google/cloud/openlineage/utils.py
+++ b/providers/src/airflow/providers/google/cloud/openlineage/utils.py
@@ -33,12 +33,15 @@
     ColumnLineageDatasetFacet,
     DocumentationDatasetFacet,
     Fields,
+    Identifier,
     InputField,
     RunFacet,
     SchemaDatasetFacet,
     SchemaDatasetFacetFields,
+    SymlinksDatasetFacet,
 )
 from airflow.providers.google import __version__ as provider_version
+from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url

 BIGQUERY_NAMESPACE = "bigquery"
 BIGQUERY_URI = "bigquery"
@@ -113,6 +116,20 @@ def get_facets_from_bq_table(table: Table) -> dict[str, BaseFacet]:
     if table.description:
         facets["documentation"] = DocumentationDatasetFacet(description=table.description)

+    if table.external_data_configuration:
+        symlinks = set()
+        for uri in table.external_data_configuration.source_uris:
+            if uri.startswith("gs://"):
+                bucket, blob = _parse_gcs_url(uri)
+                blob = extract_ds_name_from_gcs_path(blob)
+                symlinks.add((f"gs://{bucket}", blob))
+
+        facets["symlink"] = SymlinksDatasetFacet(
+            identifiers=[
+                Identifier(namespace=namespace, name=name, type="file")
+                for namespace, name in sorted(symlinks)
+            ]
+        )
     return facets
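For readers unfamiliar with the two GCS helpers used above, here is a minimal standalone sketch of what the new symlink branch computes. The parse_gcs_url and extract_ds_name functions below are simplified stand-ins written only for illustration (the real _parse_gcs_url and extract_ds_name_from_gcs_path live in the Google provider); the resulting identifiers mirror the expectations added to test_get_facets_from_bq_table later in this patch.

from urllib.parse import urlparse


def parse_gcs_url(uri: str) -> tuple[str, str]:
    # Simplified stand-in for _parse_gcs_url: split "gs://bucket/blob" into ("bucket", "blob").
    parsed = urlparse(uri)
    return parsed.netloc, parsed.path.lstrip("/")


def extract_ds_name(path: str) -> str:
    # Simplified stand-in for extract_ds_name_from_gcs_path: keep the directory part that
    # precedes a wildcard or file pattern, e.g. "path/to/files*" -> "path/to".
    path = path.split("*", 1)[0]
    return path.rsplit("/", 1)[0].strip("/") or "/"


source_uris = ["gs://bucket/path/to/files*", "gs://second_bucket/path/to/other/files*"]
symlinks = set()
for uri in source_uris:
    if uri.startswith("gs://"):
        bucket, blob = parse_gcs_url(uri)
        symlinks.add((f"gs://{bucket}", extract_ds_name(blob)))

# Matches the identifiers asserted in the updated test_get_facets_from_bq_table case.
print(sorted(symlinks))
# [('gs://bucket', 'path/to'), ('gs://second_bucket', 'path/to/other')]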
diff --git a/providers/src/airflow/providers/google/cloud/operators/bigquery.py b/providers/src/airflow/providers/google/cloud/operators/bigquery.py
index ee4d8fbc5b134..044e2bf6ee808 100644
--- a/providers/src/airflow/providers/google/cloud/operators/bigquery.py
+++ b/providers/src/airflow/providers/google/cloud/operators/bigquery.py
@@ -1365,7 +1365,7 @@ def execute(self, context: Context) -> None:
         try:
             self.log.info("Creating table")
-            table = bq_hook.create_empty_table(
+            self._table = bq_hook.create_empty_table(
                 project_id=self.project_id,
                 dataset_id=self.dataset_id,
                 table_id=self.table_id,
@@ -1382,12 +1382,15 @@ def execute(self, context: Context) -> None:
             persist_kwargs = {
                 "context": context,
                 "task_instance": self,
-                "project_id": table.to_api_repr()["tableReference"]["projectId"],
-                "dataset_id": table.to_api_repr()["tableReference"]["datasetId"],
-                "table_id": table.to_api_repr()["tableReference"]["tableId"],
+                "project_id": self._table.to_api_repr()["tableReference"]["projectId"],
+                "dataset_id": self._table.to_api_repr()["tableReference"]["datasetId"],
+                "table_id": self._table.to_api_repr()["tableReference"]["tableId"],
             }
             self.log.info(
-                "Table %s.%s.%s created successfully", table.project, table.dataset_id, table.table_id
+                "Table %s.%s.%s created successfully",
+                self._table.project,
+                self._table.dataset_id,
+                self._table.table_id,
             )
         except Conflict:
             error_msg = f"Table {self.dataset_id}.{self.table_id} already exists."
@@ -1407,6 +1410,24 @@ def execute(self, context: Context) -> None:

         BigQueryTableLink.persist(**persist_kwargs)

+    def get_openlineage_facets_on_complete(self, task_instance):
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.google.cloud.openlineage.utils import (
+            BIGQUERY_NAMESPACE,
+            get_facets_from_bq_table,
+        )
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        table_info = self._table.to_api_repr()["tableReference"]
+        table_id = ".".join((table_info["projectId"], table_info["datasetId"], table_info["tableId"]))
+        output_dataset = Dataset(
+            namespace=BIGQUERY_NAMESPACE,
+            name=table_id,
+            facets=get_facets_from_bq_table(self._table),
+        )
+
+        return OperatorLineage(outputs=[output_dataset])
+

 class BigQueryCreateExternalTableOperator(GoogleCloudBaseOperator):
     """
@@ -1632,15 +1653,15 @@ def execute(self, context: Context) -> None:
             impersonation_chain=self.impersonation_chain,
         )
         if self.table_resource:
-            table = bq_hook.create_empty_table(
+            self._table = bq_hook.create_empty_table(
                 table_resource=self.table_resource,
             )
             BigQueryTableLink.persist(
                 context=context,
                 task_instance=self,
-                dataset_id=table.to_api_repr()["tableReference"]["datasetId"],
-                project_id=table.to_api_repr()["tableReference"]["projectId"],
-                table_id=table.to_api_repr()["tableReference"]["tableId"],
+                dataset_id=self._table.to_api_repr()["tableReference"]["datasetId"],
+                project_id=self._table.to_api_repr()["tableReference"]["projectId"],
+                table_id=self._table.to_api_repr()["tableReference"]["tableId"],
             )
             return
@@ -1691,18 +1712,36 @@ def execute(self, context: Context) -> None:
             "encryptionConfiguration": self.encryption_configuration,
         }

-        table = bq_hook.create_empty_table(
+        self._table = bq_hook.create_empty_table(
             table_resource=table_resource,
         )

         BigQueryTableLink.persist(
             context=context,
             task_instance=self,
-            dataset_id=table.to_api_repr()["tableReference"]["datasetId"],
-            project_id=table.to_api_repr()["tableReference"]["projectId"],
-            table_id=table.to_api_repr()["tableReference"]["tableId"],
+            dataset_id=self._table.to_api_repr()["tableReference"]["datasetId"],
+            project_id=self._table.to_api_repr()["tableReference"]["projectId"],
+            table_id=self._table.to_api_repr()["tableReference"]["tableId"],
         )

+    def get_openlineage_facets_on_complete(self, task_instance):
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.google.cloud.openlineage.utils import (
+            BIGQUERY_NAMESPACE,
+            get_facets_from_bq_table,
+        )
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        table_info = self._table.to_api_repr()["tableReference"]
+        table_id = ".".join((table_info["projectId"], table_info["datasetId"], table_info["tableId"]))
+        output_dataset = Dataset(
+            namespace=BIGQUERY_NAMESPACE,
+            name=table_id,
+            facets=get_facets_from_bq_table(self._table),
+        )
+
+        return OperatorLineage(outputs=[output_dataset])
+

 class BigQueryDeleteDatasetOperator(GoogleCloudBaseOperator):
     """
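Both create-table operators now cache the created table on self._table and expose get_openlineage_facets_on_complete, which the OpenLineage provider's listener calls after execute finishes. A minimal sketch of the dataset naming the method produces, using a hypothetical table reference rather than a live BigQuery table:

# Hypothetical tableReference, as returned by Table.to_api_repr() on the cached self._table.
table_info = {"projectId": "my-project", "datasetId": "my_dataset", "tableId": "my_table"}

# The output dataset uses the "bigquery" namespace and a project.dataset.table name,
# mirroring the join performed in get_openlineage_facets_on_complete above.
name = ".".join((table_info["projectId"], table_info["datasetId"], table_info["tableId"]))
assert name == "my-project.my_dataset.my_table"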
diff --git a/providers/tests/google/cloud/openlineage/test_utils.py b/providers/tests/google/cloud/openlineage/test_utils.py
index 21efbf5de1655..d580e9b376a9e 100644
--- a/providers/tests/google/cloud/openlineage/test_utils.py
+++ b/providers/tests/google/cloud/openlineage/test_utils.py
@@ -27,9 +27,11 @@
     Dataset,
     DocumentationDatasetFacet,
     Fields,
+    Identifier,
     InputField,
     SchemaDatasetFacet,
     SchemaDatasetFacetFields,
+    SymlinksDatasetFacet,
 )
 from airflow.providers.google.cloud.openlineage.utils import (
     extract_ds_name_from_gcs_path,
@@ -49,6 +51,10 @@
             {"name": "field2", "type": "INTEGER"},
         ]
     },
+    "externalDataConfiguration": {
+        "sourceFormat": "CSV",
+        "sourceUris": ["gs://bucket/path/to/files*", "gs://second_bucket/path/to/other/files*"],
+    },
 }
 TEST_TABLE: Table = Table.from_api_repr(TEST_TABLE_API_REPR)
 TEST_EMPTY_TABLE_API_REPR = {
@@ -84,6 +90,12 @@ def test_get_facets_from_bq_table():
             ]
         ),
         "documentation": DocumentationDatasetFacet(description="Table description."),
+        "symlink": SymlinksDatasetFacet(
+            identifiers=[
+                Identifier(namespace="gs://bucket", name="path/to", type="file"),
+                Identifier(namespace="gs://second_bucket", name="path/to/other", type="file"),
+            ]
+        ),
     }
     result = get_facets_from_bq_table(TEST_TABLE)
     assert result == expected_facets
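To exercise just the assertions touched by this patch, one option is a targeted pytest run. This is only a sketch: it assumes a development environment with the google and openlineage provider test dependencies installed, and it reuses the file paths from the diffs above and below.

import pytest

# Select the facet test above and the new operator-level lineage tests below.
pytest.main(
    [
        "providers/tests/google/cloud/openlineage/test_utils.py",
        "providers/tests/google/cloud/operators/test_bigquery.py",
        "-k", "get_facets_from_bq_table or get_openlineage_facets_on_complete",
        "-q",
    ]
)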
diff --git a/providers/tests/google/cloud/operators/test_bigquery.py b/providers/tests/google/cloud/operators/test_bigquery.py
index 26aa18ae52f13..29f3a8db13e2c 100644
--- a/providers/tests/google/cloud/operators/test_bigquery.py
+++ b/providers/tests/google/cloud/operators/test_bigquery.py
@@ -26,7 +26,7 @@
 import pandas as pd
 import pytest
-from google.cloud.bigquery import DEFAULT_RETRY, ScalarQueryParameter
+from google.cloud.bigquery import DEFAULT_RETRY, ScalarQueryParameter, Table
 from google.cloud.exceptions import Conflict

 from airflow.exceptions import (
@@ -36,11 +36,17 @@
     TaskDeferred,
 )
 from airflow.providers.common.compat.openlineage.facet import (
+    DocumentationDatasetFacet,
     ErrorMessageRunFacet,
     ExternalQueryRunFacet,
+    Identifier,
     InputDataset,
+    SchemaDatasetFacet,
+    SchemaDatasetFacetFields,
     SQLJobFacet,
+    SymlinksDatasetFacet,
 )
+from airflow.providers.google.cloud.openlineage.utils import BIGQUERY_NAMESPACE
 from airflow.providers.google.cloud.operators.bigquery import (
     BigQueryCheckOperator,
     BigQueryColumnCheckOperator,
@@ -259,6 +265,63 @@ def test_create_existing_table(self, mock_hook, caplog, if_exists, is_conflict,
         if log_msg is not None:
             assert log_msg in caplog.text

+    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+    def test_get_openlineage_facets_on_complete(self, mock_hook):
+        schema_fields = [
+            {"name": "field1", "type": "STRING", "description": "field1 description"},
+            {"name": "field2", "type": "INTEGER"},
+        ]
+        table_resource = {
+            "tableReference": {
+                "projectId": TEST_GCP_PROJECT_ID,
+                "datasetId": TEST_DATASET,
+                "tableId": TEST_TABLE_ID,
+            },
+            "description": "Table description.",
+            "schema": {"fields": schema_fields},
+        }
+        mock_hook.return_value.create_empty_table.return_value = Table.from_api_repr(table_resource)
+        operator = BigQueryCreateEmptyTableOperator(
+            task_id=TASK_ID,
+            dataset_id=TEST_DATASET,
+            project_id=TEST_GCP_PROJECT_ID,
+            table_id=TEST_TABLE_ID,
+            schema_fields=schema_fields,
+        )
+        operator.execute(context=MagicMock())
+
+        mock_hook.return_value.create_empty_table.assert_called_once_with(
+            dataset_id=TEST_DATASET,
+            project_id=TEST_GCP_PROJECT_ID,
+            table_id=TEST_TABLE_ID,
+            schema_fields=schema_fields,
+            time_partitioning={},
+            cluster_fields=None,
+            labels=None,
+            view=None,
+            materialized_view=None,
+            encryption_configuration=None,
+            table_resource=None,
+            exists_ok=False,
+        )
+
+        result = operator.get_openlineage_facets_on_complete(None)
+        assert not result.run_facets
+        assert not result.job_facets
+        assert not result.inputs
+        assert len(result.outputs) == 1
+        assert result.outputs[0].namespace == BIGQUERY_NAMESPACE
+        assert result.outputs[0].name == f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}"
+        assert result.outputs[0].facets == {
+            "schema": SchemaDatasetFacet(
+                fields=[
+                    SchemaDatasetFacetFields(name="field1", type="STRING", description="field1 description"),
+                    SchemaDatasetFacetFields(name="field2", type="INTEGER"),
+                ]
+            ),
+            "documentation": DocumentationDatasetFacet(description="Table description."),
+        }
+

 class TestBigQueryCreateExternalTableOperator:
     @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
@@ -344,6 +407,65 @@ def test_execute_with_parquet_format(self, mock_hook):
         operator.execute(context=MagicMock())
         mock_hook.return_value.create_empty_table.assert_called_once_with(table_resource=table_resource)

+    @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+    def test_get_openlineage_facets_on_complete(self, mock_hook):
+        table_resource = {
+            "tableReference": {
+                "projectId": TEST_GCP_PROJECT_ID,
+                "datasetId": TEST_DATASET,
+                "tableId": TEST_TABLE_ID,
+            },
+            "description": "Table description.",
+            "schema": {
+                "fields": [
+                    {"name": "field1", "type": "STRING", "description": "field1 description"},
+                    {"name": "field2", "type": "INTEGER"},
+                ]
+            },
+            "externalDataConfiguration": {
+                "sourceUris": [
+                    f"gs://{TEST_GCS_BUCKET}/{source_object}" for source_object in TEST_GCS_CSV_DATA
+                ],
+                "sourceFormat": TEST_SOURCE_CSV_FORMAT,
+            },
+        }
+        mock_hook.return_value.create_empty_table.return_value = Table.from_api_repr(table_resource)
+        operator = BigQueryCreateExternalTableOperator(
+            task_id=TASK_ID,
+            bucket=TEST_GCS_BUCKET,
+            source_objects=TEST_GCS_CSV_DATA,
+            table_resource=table_resource,
+        )
+
+        mock_hook.return_value.split_tablename.return_value = (
+            TEST_GCP_PROJECT_ID,
+            TEST_DATASET,
+            TEST_TABLE_ID,
+        )
+
+        operator.execute(context=MagicMock())
+        mock_hook.return_value.create_empty_table.assert_called_once_with(table_resource=table_resource)
+
+        result = operator.get_openlineage_facets_on_complete(None)
+        assert not result.run_facets
+        assert not result.job_facets
+        assert not result.inputs
+        assert len(result.outputs) == 1
+        assert result.outputs[0].namespace == BIGQUERY_NAMESPACE
+        assert result.outputs[0].name == f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}"
+        assert result.outputs[0].facets == {
+            "schema": SchemaDatasetFacet(
+                fields=[
+                    SchemaDatasetFacetFields(name="field1", type="STRING", description="field1 description"),
+                    SchemaDatasetFacetFields(name="field2", type="INTEGER"),
+                ]
+            ),
+            "documentation": DocumentationDatasetFacet(description="Table description."),
+            "symlink": SymlinksDatasetFacet(
+                identifiers=[Identifier(namespace=f"gs://{TEST_GCS_BUCKET}", name="dir1", type="file")]
+            ),
+        }
+

 class TestBigQueryDeleteDatasetOperator:
     @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
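Taken together, the effect of the patch is that a task like the sketch below (hypothetical project, bucket and table names) now reports the created BigQuery table as an OpenLineage output dataset with schema, documentation and, for external tables, symlink facets, provided the openlineage provider is installed and a transport is configured.

from airflow.providers.google.cloud.operators.bigquery import BigQueryCreateExternalTableOperator

create_external_table = BigQueryCreateExternalTableOperator(
    task_id="create_external_table",
    bucket="example-bucket",                      # hypothetical bucket
    source_objects=["path/to/files/part-*.csv"],  # hypothetical source objects
    table_resource={
        "tableReference": {
            "projectId": "example-project",
            "datasetId": "example_dataset",
            "tableId": "example_table",
        },
        "externalDataConfiguration": {
            "sourceFormat": "CSV",
            "sourceUris": ["gs://example-bucket/path/to/files/part-*.csv"],
        },
    },
)
# After execute(), get_openlineage_facets_on_complete() is expected to report one output dataset:
#   namespace="bigquery", name="example-project.example_dataset.example_table"
# with a symlink identifier along the lines of namespace="gs://example-bucket", name="path/to/files".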