From 9bae4a5b48256982310d079df8ece0d95a643cc9 Mon Sep 17 00:00:00 2001 From: Kacper Muda Date: Wed, 27 Nov 2024 12:48:27 +0100 Subject: [PATCH] chore: unify handling of gcs paths (#44410) Signed-off-by: Kacper Muda --- .../google/cloud/openlineage/utils.py | 38 +++++++++++++++++++ .../providers/google/cloud/operators/gcs.py | 19 +++------- .../google/cloud/transfers/bigquery_to_gcs.py | 17 +++------ .../google/cloud/transfers/gcs_to_bigquery.py | 15 +++----- .../google/cloud/transfers/gcs_to_gcs.py | 18 ++------- .../google/cloud/openlineage/test_utils.py | 25 ++++++++++++ 6 files changed, 82 insertions(+), 50 deletions(-) diff --git a/providers/src/airflow/providers/google/cloud/openlineage/utils.py b/providers/src/airflow/providers/google/cloud/openlineage/utils.py index 403023f7b4315..ff6c4c05bb6a4 100644 --- a/providers/src/airflow/providers/google/cloud/openlineage/utils.py +++ b/providers/src/airflow/providers/google/cloud/openlineage/utils.py @@ -17,6 +17,8 @@ # under the License. from __future__ import annotations +import os +import pathlib from typing import TYPE_CHECKING, Any from attr import define, field @@ -40,6 +42,42 @@ BIGQUERY_NAMESPACE = "bigquery" BIGQUERY_URI = "bigquery" +WILDCARD = "*" + + +def extract_ds_name_from_gcs_path(path: str) -> str: + """ + Extract and process the dataset name from a given path. + + Args: + path: The path to process e.g. of a gcs file. + + Returns: + The processed dataset name. + """ + if WILDCARD in path: + path = path.split(WILDCARD, maxsplit=1)[0] + + # We want to end up with parent directory if the path: + # - does not refer to a file (no dot in the last segment) + # and does not explicitly end with a slash, it is treated as a prefix and removed. + # Example: "/dir/pre_" -> "/dir/" + # - contains a dot at the end, then it is treated as a prefix (created after removing the wildcard). + # Example: "/dir/file." (was "/dir/file.*" with wildcard) -> "/dir/" + last_path_segment = os.path.basename(path).rstrip(".") + if "." not in last_path_segment and not path.endswith("/"): + path = pathlib.Path(path).parent.as_posix() + + # Normalize the path: + # - Remove trailing slashes. + # - Remove leading slashes. + # - Handle edge cases for empty paths or single-dot paths. + path = path.rstrip("/") + path = path.lstrip("/") + if path in ("", "."): + path = "/" + + return path def get_facets_from_bq_table(table: Table) -> dict[str, BaseFacet]: diff --git a/providers/src/airflow/providers/google/cloud/operators/gcs.py b/providers/src/airflow/providers/google/cloud/operators/gcs.py index 0cb8527d526da..55835219dfe5a 100644 --- a/providers/src/airflow/providers/google/cloud/operators/gcs.py +++ b/providers/src/airflow/providers/google/cloud/operators/gcs.py @@ -343,6 +343,7 @@ def get_openlineage_facets_on_start(self): LifecycleStateChangeDatasetFacet, PreviousIdentifier, ) + from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path from airflow.providers.openlineage.extractors import OperatorLineage objects = [] @@ -350,12 +351,7 @@ def get_openlineage_facets_on_start(self): objects = self.objects elif self.prefix is not None: prefixes = [self.prefix] if isinstance(self.prefix, str) else self.prefix - for pref in prefixes: - # Use parent if not a file (dot not in name) and not a dir (ends with slash) - if "." not in pref.split("/")[-1] and not pref.endswith("/"): - pref = Path(pref).parent.as_posix() - pref = "/" if pref in (".", "", "/") else pref.rstrip("/") - objects.append(pref) + objects = [extract_ds_name_from_gcs_path(pref) for pref in prefixes] bucket_url = f"gs://{self.bucket_name}" input_datasets = [ @@ -921,20 +917,15 @@ def execute(self, context: Context) -> list[str]: def get_openlineage_facets_on_complete(self, task_instance): """Implement on_complete as execute() resolves object prefixes.""" from airflow.providers.common.compat.openlineage.facet import Dataset + from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path from airflow.providers.openlineage.extractors import OperatorLineage - def _parse_prefix(pref): - # Use parent if not a file (dot not in name) and not a dir (ends with slash) - if "." not in pref.split("/")[-1] and not pref.endswith("/"): - pref = Path(pref).parent.as_posix() - return "/" if pref in (".", "/", "") else pref.rstrip("/") - input_prefix, output_prefix = "/", "/" if self._source_prefix_interp is not None: - input_prefix = _parse_prefix(self._source_prefix_interp) + input_prefix = extract_ds_name_from_gcs_path(self._source_prefix_interp) if self._destination_prefix_interp is not None: - output_prefix = _parse_prefix(self._destination_prefix_interp) + output_prefix = extract_ds_name_from_gcs_path(self._destination_prefix_interp) return OperatorLineage( inputs=[ diff --git a/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index b776d102f103f..5fc788a1dc26b 100644 --- a/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ b/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -292,8 +292,6 @@ def execute_complete(self, context: Context, event: dict[str, Any]): def get_openlineage_facets_on_complete(self, task_instance): """Implement on_complete as we will include final BQ job id.""" - from pathlib import Path - from airflow.providers.common.compat.openlineage.facet import ( BaseFacet, Dataset, @@ -303,6 +301,8 @@ def get_openlineage_facets_on_complete(self, task_instance): ) from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url from airflow.providers.google.cloud.openlineage.utils import ( + WILDCARD, + extract_ds_name_from_gcs_path, get_facets_from_bq_table, get_identity_column_lineage_facet, ) @@ -333,24 +333,19 @@ def get_openlineage_facets_on_complete(self, task_instance): output_datasets = [] for uri in sorted(self.destination_cloud_storage_uris): bucket, blob = _parse_gcs_url(uri) - additional_facets = {} - if "*" in blob: - # If wildcard ("*") is used in gcs path, we want the name of dataset to be directory name, - # but we create a symlink to the full object path with wildcard. + additional_facets = {} + if WILDCARD in blob: + # For path with wildcard we attach a symlink with unmodified path. additional_facets = { "symlink": SymlinksDatasetFacet( identifiers=[Identifier(namespace=f"gs://{bucket}", name=blob, type="file")] ), } - blob = Path(blob).parent.as_posix() - if blob == ".": - # blob path does not have leading slash, but we need root dataset name to be "/" - blob = "/" dataset = Dataset( namespace=f"gs://{bucket}", - name=blob, + name=extract_ds_name_from_gcs_path(blob), facets=merge_dicts(output_dataset_facets, additional_facets), ) output_datasets.append(dataset) diff --git a/providers/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/providers/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py index 7558ebf735d1b..291608501b09d 100644 --- a/providers/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +++ b/providers/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -756,8 +756,6 @@ def on_kill(self) -> None: def get_openlineage_facets_on_complete(self, task_instance): """Implement on_complete as we will include final BQ job id.""" - from pathlib import Path - from airflow.providers.common.compat.openlineage.facet import ( Dataset, ExternalQueryRunFacet, @@ -765,6 +763,8 @@ def get_openlineage_facets_on_complete(self, task_instance): SymlinksDatasetFacet, ) from airflow.providers.google.cloud.openlineage.utils import ( + WILDCARD, + extract_ds_name_from_gcs_path, get_facets_from_bq_table, get_identity_column_lineage_facet, ) @@ -793,22 +793,17 @@ def get_openlineage_facets_on_complete(self, task_instance): for blob in sorted(source_objects): additional_facets = {} - if "*" in blob: - # If wildcard ("*") is used in gcs path, we want the name of dataset to be directory name, - # but we create a symlink to the full object path with wildcard. + if WILDCARD in blob: + # For path with wildcard we attach a symlink with unmodified path. additional_facets = { "symlink": SymlinksDatasetFacet( identifiers=[Identifier(namespace=f"gs://{self.bucket}", name=blob, type="file")] ), } - blob = Path(blob).parent.as_posix() - if blob == ".": - # blob path does not have leading slash, but we need root dataset name to be "/" - blob = "/" dataset = Dataset( namespace=f"gs://{self.bucket}", - name=blob, + name=extract_ds_name_from_gcs_path(blob), facets=merge_dicts(input_dataset_facets, additional_facets), ) input_datasets.append(dataset) diff --git a/providers/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py b/providers/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py index 07ebdef892b02..49225a4e5e629 100644 --- a/providers/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py +++ b/providers/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py @@ -551,28 +551,16 @@ def get_openlineage_facets_on_complete(self, task_instance): This means we won't have to normalize self.source_object and self.source_objects, destination bucket and so on. """ - from pathlib import Path - from airflow.providers.common.compat.openlineage.facet import Dataset + from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path from airflow.providers.openlineage.extractors import OperatorLineage - def _process_prefix(pref): - if WILDCARD in pref: - pref = pref.split(WILDCARD)[0] - # Use parent if not a file (dot not in name) and not a dir (ends with slash) - if "." not in pref.split("/")[-1] and not pref.endswith("/"): - pref = Path(pref).parent.as_posix() - return ["/" if pref in ("", "/", ".") else pref.rstrip("/")] # Adjust root path - - inputs = [] - for prefix in self.source_objects: - result = _process_prefix(prefix) - inputs.extend(result) + inputs = [extract_ds_name_from_gcs_path(path) for path in self.source_objects] if self.destination_object is None: outputs = inputs.copy() else: - outputs = _process_prefix(self.destination_object) + outputs = [extract_ds_name_from_gcs_path(self.destination_object)] return OperatorLineage( inputs=[ diff --git a/providers/tests/google/cloud/openlineage/test_utils.py b/providers/tests/google/cloud/openlineage/test_utils.py index e3f40bee1549e..21efbf5de1655 100644 --- a/providers/tests/google/cloud/openlineage/test_utils.py +++ b/providers/tests/google/cloud/openlineage/test_utils.py @@ -19,6 +19,7 @@ import json from unittest.mock import MagicMock +import pytest from google.cloud.bigquery.table import Table from airflow.providers.common.compat.openlineage.facet import ( @@ -31,6 +32,7 @@ SchemaDatasetFacetFields, ) from airflow.providers.google.cloud.openlineage.utils import ( + extract_ds_name_from_gcs_path, get_facets_from_bq_table, get_identity_column_lineage_facet, ) @@ -263,3 +265,26 @@ def test_get_identity_column_lineage_facet_no_input_datasets(): result = get_identity_column_lineage_facet(dest_field_names=field_names, input_datasets=input_datasets) assert result == {} + + +@pytest.mark.parametrize( + "input_path, expected_output", + [ + ("/path/to/file.txt", "path/to/file.txt"), # Full file path + ("file.txt", "file.txt"), # File path in root directory + ("/path/to/dir/", "path/to/dir"), # Directory path + ("/path/to/dir/*", "path/to/dir"), # Path with wildcard at the end + ("/path/to/dir/*.csv", "path/to/dir"), # Path with wildcard in file name + ("/path/to/dir/file.*", "path/to/dir"), # Path with wildcard in file extension + ("/path/to/*/dir/file.csv", "path/to"), # Path with wildcard in the middle + ("/path/to/dir/pre_", "path/to/dir"), # Path with prefix + ("/pre", "/"), # Prefix only + ("/*", "/"), # Wildcard after root slash + ("/", "/"), # Root path + ("", "/"), # Empty path + (".", "/"), # Current directory + ("*", "/"), # Wildcard only + ], +) +def test_extract_ds_name_from_gcs_path(input_path, expected_output): + assert extract_ds_name_from_gcs_path(input_path) == expected_output