Skip to content

Commit

Permalink
Remove cattrs from lineage processing.
Browse files Browse the repository at this point in the history
Cattrs was used for two reasons:

1. As a hacky way of forcing templated fields on classes
2. As a way to store the outlets in XCom without needing pickle

1 has been supported in core for a while now: classes can declare
  "template_fields" properties, which are rendered deeply (nested objects included)

2 is now done by using a combo of BaseSerialization and `attr.asdict`
  • Loading branch information
ashb committed Sep 5, 2022
1 parent 5264b63 commit 55c9ee3
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 47 deletions.
87 changes: 42 additions & 45 deletions airflow/lineage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,20 @@
# specific language governing permissions and limitations
# under the License.
"""Provides lineage support functions"""
import json
import itertools
import logging
from functools import wraps
from typing import Any, Callable, Dict, Optional, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar, cast

import attr
import jinja2
from cattr import structure, unstructure

from airflow.configuration import conf
from airflow.lineage.backend import LineageBackend
from airflow.utils.context import lazy_mapping_from_context
from airflow.utils.module_loading import import_string

ENV = jinja2.Environment()
if TYPE_CHECKING:
from airflow.utils.context import Context


PIPELINE_OUTLETS = "pipeline_outlets"
PIPELINE_INLETS = "pipeline_inlets"
Expand All @@ -39,15 +38,6 @@
log = logging.getLogger(__name__)


@attr.s(auto_attribs=True)
class Metadata:
"""Class for serialized entities."""

type_name: str = attr.ib()
source: str = attr.ib()
data: Dict = attr.ib()


def get_backend() -> Optional[LineageBackend]:
"""Gets the lineage backend if defined in the configs"""
clazz = conf.getimport("lineage", "backend", fallback=None)
Expand All @@ -64,33 +54,38 @@ def get_backend() -> Optional[LineageBackend]:
return None


def _get_instance(meta: Metadata):
"""Instantiate an object from Metadata"""
cls = import_string(meta.type_name)
return structure(meta.data, cls)
def _render_object(obj: Any, context: "Context") -> Any:
    """Render the templated fields of ``obj`` using the current task.

    Delegates to ``render_template`` on the task attached to ``context['ti']``.

    :param obj: the inlet or outlet to render.
    :param context: the task context; it must contain a ``ti`` entry whose
        ``task`` provides ``render_template``.
    :return: whatever ``render_template`` returns for ``obj``.
    """
    # The result is not necessarily a dict, hence the ``Any`` return type.
    return context['ti'].task.render_template(obj, context)


def _render_object(obj: Any, context) -> Any:
"""Renders a attr annotated object. Will set non serializable attributes to none"""
return structure(
json.loads(
ENV.from_string(json.dumps(unstructure(obj), default=lambda o: None))
.render(lazy_mapping_from_context(context))
.encode('utf-8')
),
type(obj),
)
def _deserialize(serialized: dict):
    """Rebuild a lineage entity from its serialized dict form.

    ``serialized['__type']`` holds the dotted path of the class to import and
    ``serialized['__var']`` holds the serialized keyword arguments that are
    passed to that class once deserialized.
    """
    from airflow.serialization.serialized_objects import BaseSerialization

    # This only runs on the worker side, so "blindly" importing the named class is acceptable here.
    entity_cls = import_string(serialized['__type'])
    init_kwargs = BaseSerialization.deserialize(serialized['__var'])
    return entity_cls(**init_kwargs)

def _to_dataset(obj: Any, source: str) -> Optional[Metadata]:
"""Create Metadata from attr annotated object"""
if not attr.has(obj):
return None

type_name = obj.__module__ + '.' + obj.__class__.__name__
data = unstructure(obj)
def _serialize(objs: List[Any], source: str):
"""Serialize an attrs-decorated class to JSON"""
from airflow.serialization.serialized_objects import BaseSerialization

return Metadata(type_name, source, data)
for obj in objs:
if not attr.has(obj):
continue

type_name = obj.__module__ + '.' + obj.__class__.__name__
# Only include attributes which we can pass back to the class's constructor
data = attr.asdict(obj, recurse=True, filter=lambda a, v: a.init)

yield {
k: BaseSerialization.serialize(v)
for k, v in (
('__type', type_name),
('__source', source),
('__var', data),
)
}


T = TypeVar("T", bound=Callable)
Expand All @@ -105,18 +100,19 @@ def apply_lineage(func: T) -> T:

@wraps(func)
def wrapper(self, context, *args, **kwargs):

self.log.debug("Lineage called with inlets: %s, outlets: %s", self.inlets, self.outlets)
ret_val = func(self, context, *args, **kwargs)

outlets = [unstructure(_to_dataset(x, f"{self.dag_id}.{self.task_id}")) for x in self.outlets]
inlets = [unstructure(_to_dataset(x, None)) for x in self.inlets]
outlets = list(_serialize(self.outlets, f"{self.dag_id}.{self.task_id}"))
inlets = list(_serialize(self.inlets, None))

if self.outlets:
if outlets:
self.xcom_push(
context, key=PIPELINE_OUTLETS, value=outlets, execution_date=context['ti'].execution_date
)

if self.inlets:
if inlets:
self.xcom_push(
context, key=PIPELINE_INLETS, value=inlets, execution_date=context['ti'].execution_date
)
Expand Down Expand Up @@ -161,12 +157,13 @@ def wrapper(self, context, *args, **kwargs):
if AUTO.upper() in self.inlets or AUTO.lower() in self.inlets:
task_ids = task_ids.union(task_ids.symmetric_difference(self.upstream_task_ids))

# Remove auto and task_ids
self.inlets = [i for i in self.inlets if not isinstance(i, str)]
_inlets = self.xcom_pull(context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS)

# re-instantiate the obtained inlets
_inlets = [
_get_instance(structure(item, Metadata)) for sublist in _inlets if sublist for item in sublist
]
# xcom_pull returns a list of items for each given task_id
_inlets = [_deserialize(item) for item in itertools.chain.from_iterable(_inlets)]

self.inlets.extend(_inlets)

Expand All @@ -177,9 +174,9 @@ def wrapper(self, context, *args, **kwargs):
self.outlets = [self.outlets]

# render inlets and outlets
self.inlets = [_render_object(i, context) for i in self.inlets if attr.has(i)]
self.inlets = [_render_object(i, context) for i in self.inlets]

self.outlets = [_render_object(i, context) for i in self.outlets if attr.has(i)]
self.outlets = [_render_object(i, context) for i in self.outlets]

self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets)
return func(self, context, *args, **kwargs)
Expand Down
21 changes: 20 additions & 1 deletion airflow/lineage/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
Defines the base entities that can be used for providing lineage
information.
"""
from typing import Any, Dict, List, Optional
from typing import Any, ClassVar, Dict, List, Optional

import attr

Expand All @@ -29,6 +29,8 @@
class File:
"""File entity. Refers to a file"""

template_fields: ClassVar = ("url",)

url: str = attr.ib()
type_hint: Optional[str] = None

Expand All @@ -41,13 +43,17 @@ class User:
first_name: Optional[str] = None
last_name: Optional[str] = None

template_fields: ClassVar = ("email", "first_name", "last_name")


@attr.s(auto_attribs=True, kw_only=True)
class Tag:
"""Tag or classification entity."""

tag_name: str = attr.ib()

template_fields: ClassVar = ("tag_name",)


@attr.s(auto_attribs=True, kw_only=True)
class Column:
Expand All @@ -58,6 +64,8 @@ class Column:
data_type: str = attr.ib()
tags: List[Tag] = []

template_fields: ClassVar = ("name", "description", "data_type", "tags")


# this is a temporary hack to satisfy mypy. Once
# https://github.com/python/mypy/issues/6136 is resolved, use
Expand All @@ -81,3 +89,14 @@ class Table:
owners: List[User] = []
extra: Dict[str, Any] = {}
type_hint: Optional[str] = None

template_fields: ClassVar = (
"database",
"cluster",
"name",
"tags",
"description",
"columns",
"owners",
"extra",
)
35 changes: 34 additions & 1 deletion tests/lineage/test_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,16 @@ def test_lineage(self, dag_maker):

op3.pre_execute(ctx3)
assert len(op3.inlets) == 1
assert isinstance(op3.inlets[0], File)
assert op3.inlets[0].url == f2s.format(DEFAULT_DATE)
assert op3.outlets[0] == file3
op3.post_execute(ctx3)

# skip 4

op5.pre_execute(ctx5)
assert len(op5.inlets) == 2
# Task IDs should be removed from the inlets, replaced with the outlets of those tasks
assert sorted(op5.inlets) == [file2, file3]
op5.post_execute(ctx5)

def test_lineage_render(self, dag_maker):
Expand All @@ -118,6 +120,37 @@ def test_lineage_render(self, dag_maker):
assert op1.inlets[0].url == f1s.format(DEFAULT_DATE)
assert op1.outlets[0].url == f1s.format(DEFAULT_DATE)

def test_non_attr_outlet(self, dag_maker):
    """Only the attrs-based outlet of the upstream task ends up in the 'auto' inlets downstream."""

    class A:
        pass

    plain_outlet = A()
    attrs_outlet = File("/tmp/does_not_exist_3")

    with dag_maker(dag_id='test_prepare_lineage'):
        upstream = EmptyOperator(task_id='leave1', outlets=[plain_outlet, attrs_outlet])
        downstream = EmptyOperator(task_id='leave2', inlets='auto')

        upstream >> downstream

    dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)

    upstream_ctx = Context({"ti": TI(task=upstream, run_id=dag_run.run_id), "ds": DEFAULT_DATE})
    downstream_ctx = Context({"ti": TI(task=downstream, run_id=dag_run.run_id), "ds": DEFAULT_DATE})

    # Run the upstream task's lineage hooks first so its outlets are published.
    upstream.pre_execute(upstream_ctx)
    upstream.post_execute(upstream_ctx)

    # The downstream task should pick up only the File outlet, not the plain object.
    downstream.pre_execute(downstream_ctx)
    assert downstream.inlets == [attrs_outlet]
    downstream.post_execute(downstream_ctx)

@mock.patch("airflow.lineage.get_backend")
def test_lineage_is_sent_to_backend(self, mock_get_backend, dag_maker):
class TestBackend(LineageBackend):
Expand Down

0 comments on commit 55c9ee3

Please sign in to comment.