Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add OpenLineage support for transfer operators between gcs and local #44417

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion providers/src/airflow/providers/common/io/assets/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,6 @@ def convert_asset_to_openlineage(asset: Asset, lineage_context) -> OpenLineageDa
from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset

parsed = urllib.parse.urlsplit(asset.uri)
return OpenLineageDataset(namespace=f"file://{parsed.netloc}", name=parsed.path)
return OpenLineageDataset(
namespace=f"file://{parsed.netloc}" if parsed.netloc else "file", name=parsed.path
)
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,12 @@ def execute(self, context: Context):
raise AirflowException("The size of the downloaded file is too large to push to XCom!")
else:
hook.download(bucket_name=self.bucket, object_name=self.object_name, filename=self.filename)

def get_openlineage_facets_on_start(self):
    """
    Describe this transfer for OpenLineage.

    :return: ``OperatorLineage`` whose single input is the object ``self.object_name`` in the
        ``gs://<bucket>`` namespace, and whose output is ``self.filename`` in the ``file``
        namespace, or no output at all when ``self.filename`` is falsy.
    """
    from airflow.providers.common.compat.openlineage.facet import Dataset
    from airflow.providers.openlineage.extractors import OperatorLineage

    gcs_source = Dataset(namespace=f"gs://{self.bucket}", name=self.object_name)

    # Without a filename nothing is declared on the output side.
    local_targets = []
    if self.filename:
        local_targets.append(Dataset(namespace="file", name=self.filename))

    return OperatorLineage(inputs=[gcs_source], outputs=local_targets)
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ class LocalFilesystemToGCSOperator(BaseOperator):
def __init__(
self,
*,
src,
dst,
bucket,
gcp_conn_id="google_cloud_default",
mime_type="application/octet-stream",
gzip=False,
src: str | list[str],
dst: str,
bucket: str,
gcp_conn_id: str = "google_cloud_default",
mime_type: str = "application/octet-stream",
gzip: bool = False,
chunk_size: int | None = None,
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
Expand Down Expand Up @@ -120,3 +120,38 @@ def execute(self, context: Context):
gzip=self.gzip,
chunk_size=self.chunk_size,
)

def get_openlineage_facets_on_start(self):
    """
    Describe this upload for OpenLineage.

    A string ``src`` is made absolute and reduced with ``extract_ds_name_from_gcs_path``. A
    ``symlink`` facet carrying the untouched ``src`` is attached when ``src`` contains the
    wildcard or when the reduced name differs from the absolute path. Any non-string ``src``
    is used as-is, one input dataset per entry, with no facets.

    :return: ``OperatorLineage`` with inputs in the ``file`` namespace and one output in the
        ``gs://<bucket>`` namespace.
    """
    from airflow.providers.common.compat.openlineage.facet import (
        Dataset,
        Identifier,
        SymlinksDatasetFacet,
    )
    from airflow.providers.google.cloud.openlineage.utils import WILDCARD, extract_ds_name_from_gcs_path
    from airflow.providers.openlineage.extractors import OperatorLineage

    input_facets = {}
    if not isinstance(self.src, str):
        input_names = self.src
    else:
        # Single path provided, possibly relative or with wildcard
        raw_path = str(self.src)
        abs_path = os.path.abspath(self.src)
        dataset_path = extract_ds_name_from_gcs_path(abs_path)

        # The slash is re-added only when the caller's own path was absolute; other inputs
        # keep whatever form the helper returned.
        lost_leading_slash = raw_path.startswith("/") and not dataset_path.startswith("/")
        if lost_leading_slash:
            dataset_path = f"/{dataset_path}"
        input_names = [dataset_path]

        # Compared after the slash fix-up above, so an unchanged absolute path gets no symlink.
        if WILDCARD in raw_path or abs_path != dataset_path:
            # We attach a symlink with unmodified path.
            unmodified = Identifier(namespace="file", name=raw_path, type="file")
            input_facets = {"symlink": SymlinksDatasetFacet(identifiers=[unmodified])}

    # A dst with a non-empty basename is reported verbatim; otherwise it is reduced by the helper.
    if os.path.basename(self.dst):
        output_name = self.dst
    else:
        output_name = extract_ds_name_from_gcs_path(self.dst)

    return OperatorLineage(
        inputs=[Dataset(namespace="file", name=name, facets=input_facets) for name in input_names],
        outputs=[Dataset(namespace=f"gs://{self.bucket}", name=output_name)],
    )
4 changes: 2 additions & 2 deletions providers/tests/common/io/assets/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def test_file_asset():
@pytest.mark.parametrize(
("uri", "ol_dataset"),
(
("file:///valid/path", OpenLineageDataset(namespace="file://", name="/valid/path")),
("file:///valid/path", OpenLineageDataset(namespace="file", name="/valid/path")),
(
"file://127.0.0.1:8080/dir/file.csv",
OpenLineageDataset(namespace="file://127.0.0.1:8080", name="/dir/file.csv"),
),
("file:///C://dir/file", OpenLineageDataset(namespace="file://", name="/C://dir/file")),
("file:///C://dir/file", OpenLineageDataset(namespace="file", name="/C://dir/file")),
),
)
def test_convert_asset_to_openlineage(uri, ol_dataset):
Expand Down
17 changes: 17 additions & 0 deletions providers/tests/google/cloud/transfers/test_gcs_to_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,20 @@ def test_xcom_encoding(self, mock_hook):
bucket_name=TEST_BUCKET, object_name=TEST_OBJECT
)
context["ti"].xcom_push.assert_called_once_with(key=XCOM_KEY, value=FILE_CONTENT_STR)

def test_get_openlineage_facets_on_start_(self):
    """Lineage reports the GCS object as the only input and the local file as the only output."""
    lineage = GCSToLocalFilesystemOperator(
        task_id=TASK_ID,
        bucket=TEST_BUCKET,
        object_name=TEST_OBJECT,
        filename=LOCAL_FILE_PATH,
    ).get_openlineage_facets_on_start()

    # No job- or run-level facets are expected, only datasets.
    assert not lineage.job_facets
    assert not lineage.run_facets
    assert len(lineage.outputs) == 1
    assert len(lineage.inputs) == 1

    local_target = lineage.outputs[0]
    assert local_target.namespace == "file"
    assert local_target.name == LOCAL_FILE_PATH

    gcs_source = lineage.inputs[0]
    assert gcs_source.namespace == f"gs://{TEST_BUCKET}"
    assert gcs_source.name == TEST_OBJECT
85 changes: 76 additions & 9 deletions providers/tests/google/cloud/transfers/test_local_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
import pytest

from airflow.models.dag import DAG
from airflow.providers.common.compat.openlineage.facet import (
Identifier,
SymlinksDatasetFacet,
)
from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator

pytestmark = pytest.mark.db_test
Expand Down Expand Up @@ -72,7 +76,7 @@ def test_init(self):
def test_execute(self, mock_hook):
mock_instance = mock_hook.return_value
operator = LocalFilesystemToGCSOperator(
task_id="gcs_to_file_sensor",
task_id="file_to_gcs_operator",
dag=self.dag,
src=self.testfile1,
dst="test/test1.csv",
Expand All @@ -91,7 +95,7 @@ def test_execute(self, mock_hook):
@pytest.mark.db_test
def test_execute_with_empty_src(self):
operator = LocalFilesystemToGCSOperator(
task_id="local_to_sensor",
task_id="file_to_gcs_operator",
dag=self.dag,
src="no_file.txt",
dst="test/no_file.txt",
Expand All @@ -104,7 +108,7 @@ def test_execute_with_empty_src(self):
def test_execute_multiple(self, mock_hook):
mock_instance = mock_hook.return_value
operator = LocalFilesystemToGCSOperator(
task_id="gcs_to_file_sensor", dag=self.dag, src=self.testfiles, dst="test/", **self._config
task_id="file_to_gcs_operator", dag=self.dag, src=self.testfiles, dst="test/", **self._config
)
operator.execute(None)
files_objects = zip(
Expand All @@ -127,7 +131,7 @@ def test_execute_multiple(self, mock_hook):
def test_execute_wildcard(self, mock_hook):
mock_instance = mock_hook.return_value
operator = LocalFilesystemToGCSOperator(
task_id="gcs_to_file_sensor", dag=self.dag, src="/tmp/fake*.csv", dst="test/", **self._config
task_id="file_to_gcs_operator", dag=self.dag, src="/tmp/fake*.csv", dst="test/", **self._config
)
operator.execute(None)
object_names = ["test/" + os.path.basename(fp) for fp in glob("/tmp/fake*.csv")]
Expand All @@ -145,17 +149,80 @@ def test_execute_wildcard(self, mock_hook):
]
mock_instance.upload.assert_has_calls(calls)

@pytest.mark.parametrize(
("src", "dst"),
[
("/tmp/fake*.csv", "test/test1.csv"),
("/tmp/fake*.csv", "test"),
("/tmp/fake*.csv", "test/dir"),
],
)
@mock.patch("airflow.providers.google.cloud.transfers.local_to_gcs.GCSHook", autospec=True)
def test_execute_negative(self, mock_hook):
def test_execute_negative(self, mock_hook, src, dst):
mock_instance = mock_hook.return_value
operator = LocalFilesystemToGCSOperator(
task_id="gcs_to_file_sensor",
task_id="file_to_gcs_operator",
dag=self.dag,
src="/tmp/fake*.csv",
dst="test/test1.csv",
src=src,
dst=dst,
**self._config,
)
print(glob("/tmp/fake*.csv"))
with pytest.raises(ValueError):
operator.execute(None)
mock_instance.assert_not_called()

@pytest.mark.parametrize(
    ("src", "dst", "expected_input", "expected_output", "symlink"),
    [
        ("/tmp/fake*.csv", "test/", "/tmp", "test", True),
        ("/tmp/../tmp/fake*.csv", "test/", "/tmp", "test", True),
        ("/tmp/fake1.csv", "test/test1.csv", "/tmp/fake1.csv", "test/test1.csv", False),
        ("/tmp/fake1.csv", "test/pre", "/tmp/fake1.csv", "test/pre", False),
    ],
)
def test_get_openlineage_facets_on_start_with_string_src(
    self, src, dst, expected_input, expected_output, symlink
):
    """
    Check lineage for a single string ``src``.

    :param src: local source path handed to the operator, possibly containing a wildcard.
    :param dst: destination object name or prefix handed to the operator.
    :param expected_input: expected name of the single input dataset.
    :param expected_output: expected name of the single output dataset.
    :param symlink: whether the input must carry a ``symlink`` facet pointing at the raw ``src``.
    """
    operator = LocalFilesystemToGCSOperator(
        # Same task id as the other tests of this operator in this module.
        task_id="file_to_gcs_operator",
        dag=self.dag,
        src=src,
        dst=dst,
        **self._config,
    )
    result = operator.get_openlineage_facets_on_start()
    assert not result.job_facets
    assert not result.run_facets
    assert len(result.outputs) == 1
    assert len(result.inputs) == 1
    assert result.outputs[0].name == expected_output
    assert result.inputs[0].name == expected_input
    if symlink:
        assert result.inputs[0].facets["symlink"] == SymlinksDatasetFacet(
            identifiers=[Identifier(namespace="file", name=src, type="file")]
        )
    else:
        # Plain absolute paths are reported unchanged, so no alias must be attached.
        assert "symlink" not in (result.inputs[0].facets or {})

@pytest.mark.parametrize(
    ("src", "dst", "expected_inputs", "expected_output"),
    [
        (["/tmp/fake1.csv", "/tmp/fake2.csv"], "test/", ["/tmp/fake1.csv", "/tmp/fake2.csv"], "test"),
        (["/tmp/fake1.csv", "/tmp/fake2.csv"], "", ["/tmp/fake1.csv", "/tmp/fake2.csv"], "/"),
    ],
)
def test_get_openlineage_facets_on_start_with_list_src(self, src, dst, expected_inputs, expected_output):
    """
    Check lineage when ``src`` is a list of paths.

    :param src: list of local source paths handed to the operator.
    :param dst: destination object name or prefix handed to the operator.
    :param expected_inputs: expected input dataset names, in order.
    :param expected_output: expected name of the single output dataset.
    """
    operator = LocalFilesystemToGCSOperator(
        # Same task id as the other tests of this operator in this module.
        task_id="file_to_gcs_operator",
        dag=self.dag,
        src=src,
        dst=dst,
        **self._config,
    )
    result = operator.get_openlineage_facets_on_start()
    assert not result.job_facets
    assert not result.run_facets
    assert len(result.outputs) == 1
    assert len(result.inputs) == len(expected_inputs)
    assert result.outputs[0].name == expected_output
    assert result.outputs[0].namespace == "gs://dummy"
    # Exact, ordered comparison: a membership check would also accept duplicated inputs.
    assert [inp.name for inp in result.inputs] == expected_inputs
    assert all(inp.namespace == "file" for inp in result.inputs)