From 30a609ae3b25ea71b9e5e333be897d9ef6555f65 Mon Sep 17 00:00:00 2001 From: Kacper Muda Date: Thu, 4 Jul 2024 10:46:27 +0200 Subject: [PATCH] fix OpenLineage extraction for AthenaOperator Signed-off-by: Kacper Muda --- .../providers/amazon/aws/operators/athena.py | 26 ++++++++++++++----- .../amazon/aws/operators/test_athena.py | 23 +++++++++++++++- 2 files changed, 41 insertions(+), 8 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index 5d30b93143e92..0178d60a12c9b 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -175,9 +175,6 @@ def execute(self, context: Context) -> str | None: f"query_execution_id is {self.query_execution_id}." ) - # Save output location from API response for later use in OpenLineage. - self.output_location = self.hook.get_output_location(self.query_execution_id) - return self.query_execution_id def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str: @@ -185,6 +182,9 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None if event["status"] != "success": raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}") + + # Save query_execution_id to be later used by listeners + self.query_execution_id = event["value"] return event["value"] def on_kill(self) -> None: @@ -208,14 +208,21 @@ def on_kill(self) -> None: ) self.hook.poll_query_status(self.query_execution_id, sleep_time=self.sleep_time) - def get_openlineage_facets_on_start(self) -> OperatorLineage: + def get_openlineage_facets_on_complete(self, _) -> OperatorLineage: """ Retrieve OpenLineage data by parsing SQL queries and enriching them with Athena API. In addition to CTAS query, query and calculation results are stored in S3 location. - For that reason additional output is attached with this location. + For that reason additional output is attached with this location. Instead of using the complete + path where the results are saved (user's prefix + some UUID), we are creating a dataset with the + user-provided path only. This should make it easier to match this dataset across different processes. """ - from openlineage.client.facet import ExtractionError, ExtractionErrorRunFacet, SqlJobFacet + from openlineage.client.facet import ( + ExternalQueryRunFacet, + ExtractionError, + ExtractionErrorRunFacet, + SqlJobFacet, + ) from openlineage.client.run import Dataset from airflow.providers.openlineage.extractors.base import OperatorLineage @@ -265,6 +272,11 @@ def get_openlineage_facets_on_start(self) -> OperatorLineage: ) ) + if self.query_execution_id: + run_facets["externalQuery"] = ExternalQueryRunFacet( + externalQueryId=self.query_execution_id, source="awsathena" + ) + if self.output_location: parsed = urlparse(self.output_location) outputs.append(Dataset(namespace=f"{parsed.scheme}://{parsed.netloc}", name=parsed.path or "/")) @@ -301,7 +313,7 @@ def get_openlineage_dataset(self, database, table) -> Dataset | None: ) } fields = [ - SchemaField(name=column["Name"], type=column["Type"], description=column["Comment"]) + SchemaField(name=column["Name"], type=column["Type"], description=column.get("Comment")) for column in table_metadata["TableMetadata"]["Columns"] ] if fields: diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index 66fb6b297f993..5d5a6b88c35f9 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -21,6 +21,7 @@ import pytest from openlineage.client.facet import ( + ExternalQueryRunFacet, SchemaDatasetFacet, SchemaField, SqlJobFacet, @@ -264,6 +265,24 @@ def test_is_deferred(self, mock_run_query): query_execution_id=ATHENA_QUERY_ID, ) + def test_execute_complete_reassigns_query_execution_id_after_deferring(self): + """Assert that we use query_execution_id from event after deferral.""" + + operator = AthenaOperator( + task_id="test_athena_operator", + query="SELECT * FROM TEST_TABLE", + database="TEST_DATABASE", + deferrable=True, + ) + assert operator.query_execution_id is None + + query_execution_id = "123456" + operator.execute_complete( + context=None, + event={"status": "success", "value": query_execution_id}, + ) + assert operator.query_execution_id == query_execution_id + @mock.patch.object(AthenaHook, "region_name", new_callable=mock.PropertyMock) @mock.patch.object(AthenaHook, "get_conn") def test_operator_openlineage_data(self, mock_conn, mock_region_name): @@ -285,6 +304,7 @@ def mock_get_table_metadata(CatalogName, DatabaseName, TableName): max_polling_attempts=3, dag=self.dag, ) + op.query_execution_id = "12345" # Mocking what will be available after execution expected_lineage = OperatorLineage( inputs=[ @@ -365,5 +385,6 @@ def mock_get_table_metadata(CatalogName, DatabaseName, TableName): query="INSERT INTO TEST_TABLE SELECT CUSTOMER_EMAIL FROM DISCOUNTS", ) }, + run_facets={"externalQuery": ExternalQueryRunFacet(externalQueryId="12345", source="awsathena")}, ) - assert op.get_openlineage_facets_on_start() == expected_lineage + assert op.get_openlineage_facets_on_complete(None) == expected_lineage