Skip to content

Commit

Permalink
chore: unify handling of gcs paths (#44410)
Browse files Browse the repository at this point in the history
Signed-off-by: Kacper Muda <mudakacper@gmail.com>
  • Loading branch information
kacpermuda authored Nov 27, 2024
1 parent 8d96728 commit 9bae4a5
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 50 deletions.
38 changes: 38 additions & 0 deletions providers/src/airflow/providers/google/cloud/openlineage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations

import os
import pathlib
from typing import TYPE_CHECKING, Any

from attr import define, field
Expand All @@ -40,6 +42,42 @@

BIGQUERY_NAMESPACE = "bigquery"
BIGQUERY_URI = "bigquery"
WILDCARD = "*"


def extract_ds_name_from_gcs_path(path: str) -> str:
    """
    Derive the OpenLineage dataset name from a GCS object path.

    Args:
        path: The path to process, e.g. of a gcs file.

    Returns:
        The processed dataset name.
    """
    # Everything from the first wildcard onward is discarded.
    wildcard_idx = path.find(WILDCARD)
    if wildcard_idx != -1:
        path = path[:wildcard_idx]

    # Fall back to the parent directory when the final segment looks like
    # a prefix rather than a concrete file or directory:
    #   - "/dir/pre_"  (no dot, no trailing slash)          -> "/dir/"
    #   - "/dir/file." (leftover of "/dir/file.*" wildcard) -> "/dir/"
    final_segment = os.path.basename(path).rstrip(".")
    if not path.endswith("/") and "." not in final_segment:
        path = pathlib.Path(path).parent.as_posix()

    # Normalize: drop surrounding slashes; empty or "." means the root "/".
    normalized = path.strip("/")
    return normalized if normalized not in ("", ".") else "/"


def get_facets_from_bq_table(table: Table) -> dict[str, BaseFacet]:
Expand Down
19 changes: 5 additions & 14 deletions providers/src/airflow/providers/google/cloud/operators/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,19 +343,15 @@ def get_openlineage_facets_on_start(self):
LifecycleStateChangeDatasetFacet,
PreviousIdentifier,
)
from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path
from airflow.providers.openlineage.extractors import OperatorLineage

objects = []
if self.objects is not None:
objects = self.objects
elif self.prefix is not None:
prefixes = [self.prefix] if isinstance(self.prefix, str) else self.prefix
for pref in prefixes:
# Use parent if not a file (dot not in name) and not a dir (ends with slash)
if "." not in pref.split("/")[-1] and not pref.endswith("/"):
pref = Path(pref).parent.as_posix()
pref = "/" if pref in (".", "", "/") else pref.rstrip("/")
objects.append(pref)
objects = [extract_ds_name_from_gcs_path(pref) for pref in prefixes]

bucket_url = f"gs://{self.bucket_name}"
input_datasets = [
Expand Down Expand Up @@ -921,20 +917,15 @@ def execute(self, context: Context) -> list[str]:
def get_openlineage_facets_on_complete(self, task_instance):
"""Implement on_complete as execute() resolves object prefixes."""
from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path
from airflow.providers.openlineage.extractors import OperatorLineage

def _parse_prefix(pref):
# Use parent if not a file (dot not in name) and not a dir (ends with slash)
if "." not in pref.split("/")[-1] and not pref.endswith("/"):
pref = Path(pref).parent.as_posix()
return "/" if pref in (".", "/", "") else pref.rstrip("/")

input_prefix, output_prefix = "/", "/"
if self._source_prefix_interp is not None:
input_prefix = _parse_prefix(self._source_prefix_interp)
input_prefix = extract_ds_name_from_gcs_path(self._source_prefix_interp)

if self._destination_prefix_interp is not None:
output_prefix = _parse_prefix(self._destination_prefix_interp)
output_prefix = extract_ds_name_from_gcs_path(self._destination_prefix_interp)

return OperatorLineage(
inputs=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,6 @@ def execute_complete(self, context: Context, event: dict[str, Any]):

def get_openlineage_facets_on_complete(self, task_instance):
"""Implement on_complete as we will include final BQ job id."""
from pathlib import Path

from airflow.providers.common.compat.openlineage.facet import (
BaseFacet,
Dataset,
Expand All @@ -303,6 +301,8 @@ def get_openlineage_facets_on_complete(self, task_instance):
)
from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url
from airflow.providers.google.cloud.openlineage.utils import (
WILDCARD,
extract_ds_name_from_gcs_path,
get_facets_from_bq_table,
get_identity_column_lineage_facet,
)
Expand Down Expand Up @@ -333,24 +333,19 @@ def get_openlineage_facets_on_complete(self, task_instance):
output_datasets = []
for uri in sorted(self.destination_cloud_storage_uris):
bucket, blob = _parse_gcs_url(uri)
additional_facets = {}

if "*" in blob:
# If wildcard ("*") is used in gcs path, we want the name of dataset to be directory name,
# but we create a symlink to the full object path with wildcard.
additional_facets = {}
if WILDCARD in blob:
# For path with wildcard we attach a symlink with unmodified path.
additional_facets = {
"symlink": SymlinksDatasetFacet(
identifiers=[Identifier(namespace=f"gs://{bucket}", name=blob, type="file")]
),
}
blob = Path(blob).parent.as_posix()
if blob == ".":
# blob path does not have leading slash, but we need root dataset name to be "/"
blob = "/"

dataset = Dataset(
namespace=f"gs://{bucket}",
name=blob,
name=extract_ds_name_from_gcs_path(blob),
facets=merge_dicts(output_dataset_facets, additional_facets),
)
output_datasets.append(dataset)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -756,15 +756,15 @@ def on_kill(self) -> None:

def get_openlineage_facets_on_complete(self, task_instance):
"""Implement on_complete as we will include final BQ job id."""
from pathlib import Path

from airflow.providers.common.compat.openlineage.facet import (
Dataset,
ExternalQueryRunFacet,
Identifier,
SymlinksDatasetFacet,
)
from airflow.providers.google.cloud.openlineage.utils import (
WILDCARD,
extract_ds_name_from_gcs_path,
get_facets_from_bq_table,
get_identity_column_lineage_facet,
)
Expand Down Expand Up @@ -793,22 +793,17 @@ def get_openlineage_facets_on_complete(self, task_instance):
for blob in sorted(source_objects):
additional_facets = {}

if "*" in blob:
# If wildcard ("*") is used in gcs path, we want the name of dataset to be directory name,
# but we create a symlink to the full object path with wildcard.
if WILDCARD in blob:
# For path with wildcard we attach a symlink with unmodified path.
additional_facets = {
"symlink": SymlinksDatasetFacet(
identifiers=[Identifier(namespace=f"gs://{self.bucket}", name=blob, type="file")]
),
}
blob = Path(blob).parent.as_posix()
if blob == ".":
# blob path does not have leading slash, but we need root dataset name to be "/"
blob = "/"

dataset = Dataset(
namespace=f"gs://{self.bucket}",
name=blob,
name=extract_ds_name_from_gcs_path(blob),
facets=merge_dicts(input_dataset_facets, additional_facets),
)
input_datasets.append(dataset)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -551,28 +551,16 @@ def get_openlineage_facets_on_complete(self, task_instance):
This means we won't have to normalize self.source_object and self.source_objects,
destination bucket and so on.
"""
from pathlib import Path

from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path
from airflow.providers.openlineage.extractors import OperatorLineage

def _process_prefix(pref):
if WILDCARD in pref:
pref = pref.split(WILDCARD)[0]
# Use parent if not a file (dot not in name) and not a dir (ends with slash)
if "." not in pref.split("/")[-1] and not pref.endswith("/"):
pref = Path(pref).parent.as_posix()
return ["/" if pref in ("", "/", ".") else pref.rstrip("/")] # Adjust root path

inputs = []
for prefix in self.source_objects:
result = _process_prefix(prefix)
inputs.extend(result)
inputs = [extract_ds_name_from_gcs_path(path) for path in self.source_objects]

if self.destination_object is None:
outputs = inputs.copy()
else:
outputs = _process_prefix(self.destination_object)
outputs = [extract_ds_name_from_gcs_path(self.destination_object)]

return OperatorLineage(
inputs=[
Expand Down
25 changes: 25 additions & 0 deletions providers/tests/google/cloud/openlineage/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import json
from unittest.mock import MagicMock

import pytest
from google.cloud.bigquery.table import Table

from airflow.providers.common.compat.openlineage.facet import (
Expand All @@ -31,6 +32,7 @@
SchemaDatasetFacetFields,
)
from airflow.providers.google.cloud.openlineage.utils import (
extract_ds_name_from_gcs_path,
get_facets_from_bq_table,
get_identity_column_lineage_facet,
)
Expand Down Expand Up @@ -263,3 +265,26 @@ def test_get_identity_column_lineage_facet_no_input_datasets():

result = get_identity_column_lineage_facet(dest_field_names=field_names, input_datasets=input_datasets)
assert result == {}


@pytest.mark.parametrize(
    "path, expected",
    [
        ("/path/to/file.txt", "path/to/file.txt"),  # full file path
        ("file.txt", "file.txt"),  # file in the root directory
        ("/path/to/dir/", "path/to/dir"),  # directory path
        ("/path/to/dir/*", "path/to/dir"),  # trailing wildcard
        ("/path/to/dir/*.csv", "path/to/dir"),  # wildcard in file name
        ("/path/to/dir/file.*", "path/to/dir"),  # wildcard in file extension
        ("/path/to/*/dir/file.csv", "path/to"),  # wildcard mid-path
        ("/path/to/dir/pre_", "path/to/dir"),  # prefix path
        ("/pre", "/"),  # bare prefix
        ("/*", "/"),  # wildcard right after root slash
        ("/", "/"),  # root path
        ("", "/"),  # empty path
        (".", "/"),  # current directory
        ("*", "/"),  # bare wildcard
    ],
)
def test_extract_ds_name_from_gcs_path(path, expected):
    result = extract_ds_name_from_gcs_path(path)
    assert result == expected

0 comments on commit 9bae4a5

Please sign in to comment.