Allow datasets to be used in taskflow #27540

Merged Nov 17, 2022 · 3 commits. Changes shown below are from 1 commit.
4 changes: 3 additions & 1 deletion airflow/datasets/__init__.py

```diff
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations
 
-from typing import Any
+from typing import Any, ClassVar
 from urllib.parse import urlsplit
 
 import attr
@@ -29,6 +29,8 @@ class Dataset:
     uri: str = attr.field(validator=[attr.validators.min_len(1), attr.validators.max_len(3000)])
     extra: dict[str, Any] | None = None
 
+    version: ClassVar[int] = 1
+
     @uri.validator
     def _check_uri(self, attr, uri: str):
         if uri.isspace():
```
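The new class-level `version` is picked up by the `XComEncoder` introduced in `airflow/utils/json.py` later in this diff, which stamps it into the serialized payload. A minimal sketch of the resulting wire format (the URI and `extra` values are illustrative, not from the PR):

```python
# Sketch only: shows how Dataset's ClassVar ``version`` surfaces in the
# JSON payload produced by the XComEncoder added later in this PR.
import json

from airflow import Dataset
from airflow.utils.json import XComEncoder

ds = Dataset(uri="s3://bucket/raw.csv", extra={"owner": "data-eng"})  # illustrative values
payload = json.dumps(ds, cls=XComEncoder)

# The payload carries the import path, the class version, and the attrs data,
# roughly: {"__classname__": "airflow.datasets.Dataset", "__version__": 1,
#           "__data__": ...serialized uri/extra...}
print(payload)
```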
2 changes: 1 addition & 1 deletion airflow/datasets/manager.py

```diff
@@ -53,7 +53,7 @@ def register_dataset_change(
         """
         dataset_model = session.query(DatasetModel).filter(DatasetModel.uri == dataset.uri).one_or_none()
         if not dataset_model:
-            self.log.warning("DatasetModel %s not found", dataset_model)
+            self.log.warning("DatasetModel %s not found", dataset)
             return
         session.add(
             DatasetEvent(
```
19 changes: 18 additions & 1 deletion airflow/decorators/base.py

```diff
@@ -18,6 +18,7 @@
 
 import inspect
 import re
+from itertools import chain
 from typing import (
     Any,
     Callable,
@@ -37,6 +38,7 @@
 import typing_extensions
 from sqlalchemy.orm import Session
 
+from airflow import Dataset
 from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
 from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY
@@ -207,17 +209,30 @@ def __init__(
         super().__init__(task_id=task_id, **kwargs_to_upstream, **kwargs)
 
     def execute(self, context: Context):
+        # TODO: make this more generic (move to prepare_lineage) so it deals
+        # with non-taskflow operators as well
+        for arg in chain(self.op_args, self.op_kwargs.values()):
+            if isinstance(arg, Dataset):
+                self.inlets.append(arg)
         return_value = super().execute(context)
         return self._handle_output(return_value=return_value, context=context, xcom_push=self.xcom_push)
 
     def _handle_output(self, return_value: Any, context: Context, xcom_push: Callable):
         """
         Handles logic for whether a decorator needs to push a single return value or multiple return values.
 
+        It sets outlets if any datasets are found in the returned value(s).
+
         :param return_value:
         :param context:
         :param xcom_push:
         """
+        if isinstance(return_value, Dataset):
+            self.outlets.append(return_value)
+        if isinstance(return_value, list):
+            for item in return_value:
+                if isinstance(item, Dataset):
+                    self.outlets.append(item)
         if not self.multiple_outputs:
             return return_value
         if isinstance(return_value, dict):
@@ -228,6 +243,8 @@ def _handle_output(self, return_value: Any, context: Context, xcom_push: Callable):
                     f"multiple_outputs, found {key} ({type(key)}) instead"
                 )
             for key, value in return_value.items():
+                if isinstance(value, Dataset):
+                    self.outlets.append(value)
                 xcom_push(context, key, value)
         else:
             raise AirflowException(
@@ -280,7 +297,7 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubclass]):
     def _infer_multiple_outputs(self):
         try:
             return_type = typing_extensions.get_type_hints(self.function).get("return", Any)
-        except Exception:  # Can't evaluate retrurn type.
+        except TypeError:  # Can't evaluate return type.
            return False
         ttype = getattr(return_type, "__origin__", return_type)
         return ttype == dict or ttype == Dict
```
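To see the decorator changes end to end, here is a minimal, hypothetical DAG (not part of this PR's diff): a `Dataset` passed as a task argument is collected into `inlets` by `execute`, and a `Dataset` returned from the task is collected into `outlets` by `_handle_output`.

```python
# Hypothetical example DAG illustrating the new behavior; task names and
# URIs are made up for illustration.
import pendulum

from airflow import Dataset
from airflow.decorators import dag, task

raw = Dataset("s3://bucket/raw.csv")
clean = Dataset("s3://bucket/clean.parquet")


@dag(schedule=None, start_date=pendulum.datetime(2022, 11, 1))
def dataset_taskflow():
    @task
    def transform(src: Dataset) -> Dataset:
        # ``src`` is appended to this task's inlets before execution;
        # the returned Dataset is appended to its outlets afterwards.
        return clean

    transform(raw)


dataset_taskflow()
```

Dict returns behave the same way when `multiple_outputs` is set: any `Dataset` values are added to `outlets` as each key is pushed to XCom.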
41 changes: 5 additions & 36 deletions airflow/lineage/__init__.py

```diff
@@ -23,11 +23,8 @@
 from functools import wraps
 from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
 
-import attr
-
 from airflow.configuration import conf
 from airflow.lineage.backend import LineageBackend
-from airflow.utils.module_loading import import_string
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -60,36 +57,6 @@ def _render_object(obj: Any, context: Context) -> dict:
     return context["ti"].task.render_template(obj, context)
 
 
-def _deserialize(serialized: dict):
-    from airflow.serialization.serialized_objects import BaseSerialization
-
-    # This is only use in the worker side, so it is okay to "blindly" import the specified class here.
-    cls = import_string(serialized["__type"])
-    return cls(**BaseSerialization.deserialize(serialized["__var"]))
-
-
-def _serialize(objs: list[Any], source: str):
-    """Serialize an attrs-decorated class to JSON."""
-    from airflow.serialization.serialized_objects import BaseSerialization
-
-    for obj in objs:
-        if not attr.has(obj):
-            continue
-
-        type_name = obj.__module__ + "." + obj.__class__.__name__
-        # Only include attributes which we can pass back to the classes constructor
-        data = attr.asdict(obj, recurse=True, filter=lambda a, v: a.init)
-
-        yield {
-            k: BaseSerialization.serialize(v)
-            for k, v in (
-                ("__type", type_name),
-                ("__source", source),
-                ("__var", data),
-            )
-        }
-
-
 T = TypeVar("T", bound=Callable)
 
 
@@ -106,10 +73,11 @@ def apply_lineage(func: T) -> T:
     def wrapper(self, context, *args, **kwargs):
 
         self.log.debug("Lineage called with inlets: %s, outlets: %s", self.inlets, self.outlets)
+
         ret_val = func(self, context, *args, **kwargs)
 
-        outlets = list(_serialize(self.outlets, f"{self.dag_id}.{self.task_id}"))
-        inlets = list(_serialize(self.inlets, None))
+        outlets = list(self.outlets)
+        inlets = list(self.inlets)
 
         if outlets:
             self.xcom_push(
@@ -169,7 +137,7 @@ def wrapper(self, context, *args, **kwargs):
 
         # re-instantiate the obtained inlets
         # xcom_pull returns a list of items for each given task_id
-        _inlets = [_deserialize(item) for item in itertools.chain.from_iterable(_inlets)]
+        _inlets = [item for item in itertools.chain.from_iterable(_inlets)]
 
         self.inlets.extend(_inlets)
 
@@ -185,6 +153,7 @@ def wrapper(self, context, *args, **kwargs):
             self.outlets = [_render_object(i, context) for i in self.outlets]
 
         self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets)
+
         return func(self, context, *args, **kwargs)
 
     return cast(T, wrapper)
```
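The removed `_serialize`/`_deserialize` helpers are no longer needed because the XCom layer (see `airflow/models/xcom.py` and `airflow/utils/json.py` below) now round-trips attrs-decorated objects itself. A sketch of that equivalence, using a hypothetical attrs class:

```python
# Sketch under the assumption that the class is importable by its module
# path (the decoder re-imports it via import_string); ``Marker`` is a
# hypothetical stand-in for a lineage object such as Dataset.
import json

import attr

from airflow.utils.json import XComDecoder, XComEncoder


@attr.define
class Marker:
    name: str


encoded = json.dumps(Marker(name="upstream"), cls=XComEncoder)
decoded = json.loads(encoded, cls=XComDecoder)
assert decoded == Marker(name="upstream")
```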
29 changes: 20 additions & 9 deletions airflow/models/xcom.py

```diff
@@ -51,6 +51,7 @@
 from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
 from airflow.utils import timezone
 from airflow.utils.helpers import exactly_one, is_container
+from airflow.utils.json import XComDecoder, XComEncoder
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
@@ -620,32 +621,42 @@ def serialize_value(
         if conf.getboolean("core", "enable_xcom_pickling"):
             return pickle.dumps(value)
         try:
-            return json.dumps(value).encode("UTF-8")
-        except (ValueError, TypeError):
+            return json.dumps(value, cls=XComEncoder).encode("UTF-8")
+        except (ValueError, TypeError) as ex:
             log.error(
-                "Could not serialize the XCom value into JSON."
+                "%s."
                 " If you are using pickle instead of JSON for XCom,"
                 " then you need to enable pickle support for XCom"
-                " in your airflow config."
+                " in your airflow config or make sure to decorate your"
+                " object with attr.",
+                ex,
             )
             raise
 
     @staticmethod
-    def deserialize_value(result: XCom) -> Any:
-        """Deserialize XCom value from str or pickle object"""
+    def _deserialize_value(result: XCom, orm: bool) -> Any:
+        object_hook = None
+        if orm:
+            object_hook = XComDecoder.orm_object_hook
+
         if result.value is None:
             return None
         if conf.getboolean("core", "enable_xcom_pickling"):
             try:
                 return pickle.loads(result.value)
             except pickle.UnpicklingError:
-                return json.loads(result.value.decode("UTF-8"))
+                return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
         else:
             try:
-                return json.loads(result.value.decode("UTF-8"))
+                return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
             except (json.JSONDecodeError, UnicodeDecodeError):
                 return pickle.loads(result.value)
 
+    @staticmethod
+    def deserialize_value(result: XCom) -> Any:
+        """Deserialize XCom value from str or pickle object"""
+        return BaseXCom._deserialize_value(result, False)
+
     def orm_deserialize_value(self) -> Any:
         """
         Deserialize method which is used to reconstruct ORM XCom object.
@@ -655,7 +666,7 @@ def orm_deserialize_value(self) -> Any:
         creating XCom orm model. This is used when viewing XCom listing
         in the webserver, for example.
         """
-        return BaseXCom.deserialize_value(self)
+        return BaseXCom._deserialize_value(self, True)
 
 
 class _LazyXComAccessIterator(collections.abc.Iterator):
```
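With pickling disabled (the default), the push/pull path now goes through `XComEncoder`/`XComDecoder`; the `orm` flag only swaps in the human-readable hook used by the webserver's XCom listing. A rough sketch, faking the ORM row with a namespace object:

```python
# Rough sketch; assumes core.enable_xcom_pickling is False. SimpleNamespace
# stands in for an XCom ORM row, which only needs a ``value`` attribute here.
from types import SimpleNamespace

from airflow import Dataset
from airflow.models.xcom import BaseXCom

blob = BaseXCom.serialize_value(Dataset("s3://bucket/raw.csv"))  # JSON bytes

row = SimpleNamespace(value=blob)
restored = BaseXCom.deserialize_value(row)         # full Dataset instance
readable = BaseXCom._deserialize_value(row, True)  # readable string like
# "airflow.datasets.Dataset@version=1(uri=s3://bucket/raw.csv,extra=None)"
```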
127 changes: 126 additions & 1 deletion airflow/utils/json.py

```diff
@@ -17,13 +17,18 @@
 # under the License.
 from __future__ import annotations
 
+import dataclasses
 import json
 import logging
 from datetime import date, datetime
 from decimal import Decimal
 from typing import Any
 
+import attr
 from flask.json.provider import JSONProvider
 
+from airflow.serialization.enums import Encoding
+from airflow.utils.module_loading import import_string
 from airflow.utils.timezone import convert_to_utc, is_naive
 
 try:
@@ -40,6 +45,16 @@
 
 log = logging.getLogger(__name__)
 
+CLASSNAME = "__classname__"
+VERSION = "__version__"
+DATA = "__data__"
+
+OLD_TYPE = "__type"
+OLD_SOURCE = "__source"
+OLD_DATA = "__var"
+
+DEFAULT_VERSION = 0
+
 
 class AirflowJsonEncoder(json.JSONEncoder):
     """Custom Airflow json encoder implementation."""
@@ -107,7 +122,7 @@ def safe_get_name(pod):
             log.debug("traceback for pod JSON encode error", exc_info=True)
             return {}
 
-        raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")
+        raise TypeError(f"Object of type '{obj.__class__.__qualname__}' is not JSON serializable")
 
 
 class AirflowJsonProvider(JSONProvider):
@@ -123,3 +138,113 @@ def dumps(self, obj, **kwargs):
 
     def loads(self, s: str | bytes, **kwargs):
         return json.loads(s, **kwargs)
+
+
+# for now kept separate, as AirflowJsonEncoder is non-standard
+class XComEncoder(json.JSONEncoder):
+    """This encoder serializes any object that has attr, dataclass or a custom serializer."""
+
+    def default(self, o: object) -> dict:
+        from airflow.serialization.serialized_objects import BaseSerialization
+
+        dct = {
+            CLASSNAME: o.__module__ + "." + o.__class__.__qualname__,
+            VERSION: getattr(o.__class__, "version", DEFAULT_VERSION),
+        }
+
+        if hasattr(o, "serialize"):
+            dct[DATA] = getattr(o, "serialize")()
+            return dct
+        elif dataclasses.is_dataclass(o.__class__):
+            data = dataclasses.asdict(o)
+            dct[DATA] = BaseSerialization.serialize(data)
+            return dct
+        elif attr.has(o.__class__):
+            # Only include attributes which we can pass back to the class's constructor
+            data = attr.asdict(o, recurse=True, filter=lambda a, v: a.init)  # type: ignore[arg-type]
+            dct[DATA] = BaseSerialization.serialize(data)
+            return dct
+        else:
+            return super().default(o)
+
+    def encode(self, o: Any) -> str:
+        if isinstance(o, dict) and CLASSNAME in o:
+            raise AttributeError(f"reserved key {CLASSNAME} found in dict to serialize")
+
+        return super().encode(o)
+
+
+class XComDecoder(json.JSONDecoder):
+    """
+    This decoder deserializes dicts to objects if they contain
+    the `__classname__` key; otherwise the dict is returned as is.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        if not kwargs.get("object_hook"):
+            kwargs["object_hook"] = self.object_hook
+
+        super().__init__(*args, **kwargs)
+
+    @staticmethod
+    def object_hook(dct: dict) -> object:
+        dct = XComDecoder._convert(dct)
+
+        if CLASSNAME in dct and VERSION in dct:
+            from airflow.serialization.serialized_objects import BaseSerialization
+
+            cls = import_string(dct[CLASSNAME])
+
+            if hasattr(cls, "deserialize"):
+                return getattr(cls, "deserialize")(dct[DATA], dct[VERSION])
+
+            version = getattr(cls, "version", 0)
+            if int(dct[VERSION]) > version:
+                raise TypeError(
+                    f"serialized version of {dct[CLASSNAME]} is newer than module version "
+                    f"({dct[VERSION]} > {version})"
+                )
+
+            if not attr.has(cls) and not dataclasses.is_dataclass(cls):
+                raise TypeError(
+                    f"cannot deserialize: no deserialization method "
+                    f"for {dct[CLASSNAME]} and not attr/dataclass decorated"
+                )
+
+            return cls(**BaseSerialization.deserialize(dct[DATA]))
+
+        return dct
+
+    @staticmethod
+    def orm_object_hook(dct: dict) -> object:
+        """Creates a readable representation of a serialized object"""
+        dct = XComDecoder._convert(dct)
+        if CLASSNAME in dct and VERSION in dct:
+            from airflow.serialization.serialized_objects import BaseSerialization
+
+            if Encoding.VAR in dct[DATA] and Encoding.TYPE in dct[DATA]:
+                data = BaseSerialization.deserialize(dct[DATA])
+                if not isinstance(data, dict):
+                    raise TypeError(f"deserialized value should be a dict, but is {type(data)}")
+            else:
+                # custom serializer
+                data = dct[DATA]
+
+            s = f"{dct[CLASSNAME]}@version={dct[VERSION]}("
+            for k, v in data.items():
+                s += f"{k}={v},"
+            s = s[:-1] + ")"
+            return s
+
+        return dct
+
+    @staticmethod
+    def _convert(old: dict) -> dict:
+        """Converts an old-style serialization to the new style"""
+        if OLD_TYPE in old and OLD_SOURCE in old:
+            return {CLASSNAME: old[OLD_TYPE], VERSION: DEFAULT_VERSION, DATA: old[OLD_DATA]}
+
+        return old
```
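Two behaviors of the decoder are worth calling out: payloads written in the old lineage layout (`__type`/`__source`/`__var`) are converted by `_convert` before deserialization, and a payload whose `__version__` is newer than the local class's `version` is rejected. A short sketch of the version guard, tampering with the payload by hand to simulate data written by a newer Airflow (not PR code):

```python
# Sketch only: manually bumps the version stamp to trigger the guard.
import json

from airflow import Dataset
from airflow.utils.json import XComDecoder, XComEncoder

payload = json.loads(json.dumps(Dataset("s3://bucket/raw.csv"), cls=XComEncoder))
payload["__version__"] = 99  # pretend this was written by a newer Dataset class

try:
    json.loads(json.dumps(payload), cls=XComDecoder)
except TypeError as err:
    print(err)  # serialized version of airflow.datasets.Dataset is newer than module version (99 > 1)
```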