Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: unify handling of gcs paths in OpenLineage processes #44410

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions providers/src/airflow/providers/google/cloud/openlineage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations

import os
import pathlib
from typing import TYPE_CHECKING, Any

from attr import define, field
Expand All @@ -40,6 +42,42 @@

BIGQUERY_NAMESPACE = "bigquery"
BIGQUERY_URI = "bigquery"
WILDCARD = "*"


def extract_ds_name_from_gcs_path(path: str) -> str:
    """
    Extract and process the dataset name from a given path.

    The path is cut at the first wildcard, reduced to its parent directory when
    the last segment looks like a prefix rather than a file, and stripped of
    leading and trailing slashes. An empty result is reported as the root ``"/"``.

    GCS object names always use ``"/"`` as the separator, so all path handling
    here is POSIX-flavoured regardless of the operating system the code runs on.

    Examples:
        "/path/to/file.txt"       -> "path/to/file.txt"
        "/path/to/dir/*.csv"      -> "path/to/dir"
        "/path/to/dir/file.*"     -> "path/to/dir"
        "/path/to/dir/pre_"       -> "path/to/dir"
        "/pre", "/", "", ".", "*" -> "/"

    Args:
        path: The path to process e.g. of a gcs file.

    Returns:
        The processed dataset name.
    """
    if WILDCARD in path:
        # Everything from the first wildcard onwards cannot name a concrete location.
        path = path.split(WILDCARD, maxsplit=1)[0]

    # Decide whether the last segment is a file name or merely a prefix:
    # - no dot in the segment and no trailing slash -> prefix, e.g. "/dir/pre_" -> "/dir"
    # - a trailing dot is what remains of "file.*" once the wildcard is cut off,
    #   so it is ignored and the segment is treated as a prefix as well,
    #   e.g. "/dir/file." -> "/dir"
    # Splitting on "/" explicitly (instead of os.path.basename) keeps the result
    # independent of the host OS path separator.
    last_path_segment = path.rpartition("/")[2].rstrip(".")
    if "." not in last_path_segment and not path.endswith("/"):
        path = pathlib.PurePosixPath(path).parent.as_posix()

    # Normalize: drop leading/trailing slashes, and map the empty path and the
    # "current directory" marker produced by pathlib to the root dataset name.
    path = path.strip("/")
    if path in ("", "."):
        path = "/"

    return path


def get_facets_from_bq_table(table: Table) -> dict[str, BaseFacet]:
Expand Down
19 changes: 5 additions & 14 deletions providers/src/airflow/providers/google/cloud/operators/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,19 +343,15 @@ def get_openlineage_facets_on_start(self):
LifecycleStateChangeDatasetFacet,
PreviousIdentifier,
)
from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path
from airflow.providers.openlineage.extractors import OperatorLineage

objects = []
if self.objects is not None:
objects = self.objects
elif self.prefix is not None:
prefixes = [self.prefix] if isinstance(self.prefix, str) else self.prefix
for pref in prefixes:
# Use parent if not a file (dot not in name) and not a dir (ends with slash)
if "." not in pref.split("/")[-1] and not pref.endswith("/"):
pref = Path(pref).parent.as_posix()
pref = "/" if pref in (".", "", "/") else pref.rstrip("/")
objects.append(pref)
objects = [extract_ds_name_from_gcs_path(pref) for pref in prefixes]

bucket_url = f"gs://{self.bucket_name}"
input_datasets = [
Expand Down Expand Up @@ -921,20 +917,15 @@ def execute(self, context: Context) -> list[str]:
def get_openlineage_facets_on_complete(self, task_instance):
"""Implement on_complete as execute() resolves object prefixes."""
from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path
from airflow.providers.openlineage.extractors import OperatorLineage

def _parse_prefix(pref):
# Use parent if not a file (dot not in name) and not a dir (ends with slash)
if "." not in pref.split("/")[-1] and not pref.endswith("/"):
pref = Path(pref).parent.as_posix()
return "/" if pref in (".", "/", "") else pref.rstrip("/")

input_prefix, output_prefix = "/", "/"
if self._source_prefix_interp is not None:
input_prefix = _parse_prefix(self._source_prefix_interp)
input_prefix = extract_ds_name_from_gcs_path(self._source_prefix_interp)

if self._destination_prefix_interp is not None:
output_prefix = _parse_prefix(self._destination_prefix_interp)
output_prefix = extract_ds_name_from_gcs_path(self._destination_prefix_interp)

return OperatorLineage(
inputs=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,6 @@ def execute_complete(self, context: Context, event: dict[str, Any]):

def get_openlineage_facets_on_complete(self, task_instance):
"""Implement on_complete as we will include final BQ job id."""
from pathlib import Path

from airflow.providers.common.compat.openlineage.facet import (
BaseFacet,
Dataset,
Expand All @@ -303,6 +301,8 @@ def get_openlineage_facets_on_complete(self, task_instance):
)
from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url
from airflow.providers.google.cloud.openlineage.utils import (
WILDCARD,
extract_ds_name_from_gcs_path,
get_facets_from_bq_table,
get_identity_column_lineage_facet,
)
Expand Down Expand Up @@ -333,24 +333,19 @@ def get_openlineage_facets_on_complete(self, task_instance):
output_datasets = []
for uri in sorted(self.destination_cloud_storage_uris):
bucket, blob = _parse_gcs_url(uri)
additional_facets = {}

if "*" in blob:
# If wildcard ("*") is used in gcs path, we want the name of dataset to be directory name,
# but we create a symlink to the full object path with wildcard.
additional_facets = {}
if WILDCARD in blob:
# For path with wildcard we attach a symlink with unmodified path.
additional_facets = {
"symlink": SymlinksDatasetFacet(
identifiers=[Identifier(namespace=f"gs://{bucket}", name=blob, type="file")]
),
}
blob = Path(blob).parent.as_posix()
if blob == ".":
# blob path does not have leading slash, but we need root dataset name to be "/"
blob = "/"

dataset = Dataset(
namespace=f"gs://{bucket}",
name=blob,
name=extract_ds_name_from_gcs_path(blob),
facets=merge_dicts(output_dataset_facets, additional_facets),
)
output_datasets.append(dataset)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -756,15 +756,15 @@ def on_kill(self) -> None:

def get_openlineage_facets_on_complete(self, task_instance):
"""Implement on_complete as we will include final BQ job id."""
from pathlib import Path

from airflow.providers.common.compat.openlineage.facet import (
Dataset,
ExternalQueryRunFacet,
Identifier,
SymlinksDatasetFacet,
)
from airflow.providers.google.cloud.openlineage.utils import (
WILDCARD,
extract_ds_name_from_gcs_path,
get_facets_from_bq_table,
get_identity_column_lineage_facet,
)
Expand Down Expand Up @@ -793,22 +793,17 @@ def get_openlineage_facets_on_complete(self, task_instance):
for blob in sorted(source_objects):
additional_facets = {}

if "*" in blob:
# If wildcard ("*") is used in gcs path, we want the name of dataset to be directory name,
# but we create a symlink to the full object path with wildcard.
if WILDCARD in blob:
# For path with wildcard we attach a symlink with unmodified path.
additional_facets = {
"symlink": SymlinksDatasetFacet(
identifiers=[Identifier(namespace=f"gs://{self.bucket}", name=blob, type="file")]
),
}
blob = Path(blob).parent.as_posix()
if blob == ".":
# blob path does not have leading slash, but we need root dataset name to be "/"
blob = "/"

dataset = Dataset(
namespace=f"gs://{self.bucket}",
name=blob,
name=extract_ds_name_from_gcs_path(blob),
facets=merge_dicts(input_dataset_facets, additional_facets),
)
input_datasets.append(dataset)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -551,28 +551,16 @@ def get_openlineage_facets_on_complete(self, task_instance):
This means we won't have to normalize self.source_object and self.source_objects,
destination bucket and so on.
"""
from pathlib import Path

from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path
from airflow.providers.openlineage.extractors import OperatorLineage

def _process_prefix(pref):
if WILDCARD in pref:
pref = pref.split(WILDCARD)[0]
# Use parent if not a file (dot not in name) and not a dir (ends with slash)
if "." not in pref.split("/")[-1] and not pref.endswith("/"):
pref = Path(pref).parent.as_posix()
return ["/" if pref in ("", "/", ".") else pref.rstrip("/")] # Adjust root path

inputs = []
for prefix in self.source_objects:
result = _process_prefix(prefix)
inputs.extend(result)
inputs = [extract_ds_name_from_gcs_path(path) for path in self.source_objects]

if self.destination_object is None:
outputs = inputs.copy()
else:
outputs = _process_prefix(self.destination_object)
outputs = [extract_ds_name_from_gcs_path(self.destination_object)]

return OperatorLineage(
inputs=[
Expand Down
25 changes: 25 additions & 0 deletions providers/tests/google/cloud/openlineage/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import json
from unittest.mock import MagicMock

import pytest
from google.cloud.bigquery.table import Table

from airflow.providers.common.compat.openlineage.facet import (
Expand All @@ -31,6 +32,7 @@
SchemaDatasetFacetFields,
)
from airflow.providers.google.cloud.openlineage.utils import (
extract_ds_name_from_gcs_path,
get_facets_from_bq_table,
get_identity_column_lineage_facet,
)
Expand Down Expand Up @@ -263,3 +265,26 @@ def test_get_identity_column_lineage_facet_no_input_datasets():

result = get_identity_column_lineage_facet(dest_field_names=field_names, input_datasets=input_datasets)
assert result == {}


@pytest.mark.parametrize(
    "input_path, expected_output",
    [
        pytest.param("/path/to/file.txt", "path/to/file.txt", id="full-file-path"),
        pytest.param("path/to/file.txt", "path/to/file.txt", id="file-path-without-leading-slash"),
        pytest.param("file.txt", "file.txt", id="file-in-root-directory"),
        pytest.param("/path/to/dir/", "path/to/dir", id="directory-path"),
        pytest.param("/path/to/dir/*", "path/to/dir", id="wildcard-at-the-end"),
        pytest.param("/path/to/dir/*.csv", "path/to/dir", id="wildcard-in-file-name"),
        pytest.param("/path/to/dir/file.*", "path/to/dir", id="wildcard-in-file-extension"),
        pytest.param("/path/to/*/dir/file.csv", "path/to", id="wildcard-in-the-middle"),
        pytest.param("/path/*/to/*.csv", "path", id="multiple-wildcards"),
        pytest.param("/path/to/dir/pre_", "path/to/dir", id="path-with-prefix"),
        pytest.param("/path/to/dir.v1/pre_", "path/to/dir.v1", id="prefix-under-dotted-directory"),
        pytest.param("dir\\file.csv", "dir\\file.csv", id="backslash-is-not-a-separator"),
        pytest.param("/pre", "/", id="prefix-only"),
        pytest.param("/*", "/", id="wildcard-after-root-slash"),
        pytest.param("/", "/", id="root-path"),
        pytest.param("", "/", id="empty-path"),
        pytest.param(".", "/", id="current-directory"),
        pytest.param("*", "/", id="wildcard-only"),
    ],
)
def test_extract_ds_name_from_gcs_path(input_path, expected_output):
    """Check that each kind of GCS path is reduced to the expected dataset name."""
    assert extract_ds_name_from_gcs_path(input_path) == expected_output