From 489ca1494691f6a0b717f9312e0ab8a2a7f76d96 Mon Sep 17 00:00:00 2001
From: Jens Scheffler <95105677+jens-scheffler-bosch@users.noreply.github.com>
Date: Wed, 16 Aug 2023 08:04:52 +0200
Subject: [PATCH] Fix handling of default value and serialization of Param
 class (#33141)

---
Notes (ignored by `git am`; not part of the commit message):

* Param.dump() now writes None for "value" when the value is NOTSET.
* Param.has_value is now False when the value is None as well as when it
  is NOTSET.
* serialized_objects.py encodes ArgNotSet instances with the new
  DagAttributeTypes.ARG_NOT_SET tag and decodes that tag back to NOTSET.
* Tests cover Param() without a default in both the Param round trip and
  the full DAG round trip.
* NOTE(review): line breaks and indentation of this copy were restored from
  the hunk headers' line counts and Python syntax; compare against the
  upstream commit before applying.

 airflow/models/param.py                       |  9 ++++--
 airflow/serialization/enums.py                |  1 +
 airflow/serialization/serialized_objects.py   |  5 +++
 tests/models/test_param.py                    | 31 +++++++++++++++++++
 tests/serialization/test_dag_serialization.py |  7 +++--
 5 files changed, 48 insertions(+), 5 deletions(-)

diff --git a/airflow/models/param.py b/airflow/models/param.py
index 82a49f715d8ac..bea4333cf5105 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -138,13 +138,18 @@ def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:
 
     def dump(self) -> dict:
         """Dump the Param as a dictionary."""
-        out_dict = {self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"}
+        out_dict: dict[str, str | None] = {
+            self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"
+        }
         out_dict.update(self.__dict__)
+        # Ensure that not set is translated to None
+        if self.value is NOTSET:
+            out_dict["value"] = None
         return out_dict
 
     @property
     def has_value(self) -> bool:
-        return self.value is not NOTSET
+        return self.value is not NOTSET and self.value is not None
 
     def serialize(self) -> dict:
         return {"value": self.value, "description": self.description, "schema": self.schema}
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index c83d9f53ef87c..0b1c0ca009c35 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -55,3 +55,4 @@ class DagAttributeTypes(str, Enum):
     TASK_INSTANCE = "task_instance"
     DAG_RUN = "dag_run"
     DATA_SET = "data_set"
+    ARG_NOT_SET = "arg_not_set"
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 17a084f20ec2d..67d08b7a94fd8 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -65,6 +65,7 @@
 from airflow.utils.module_loading import import_string, qualname
 from airflow.utils.operator_resources import Resources
 from airflow.utils.task_group import MappedTaskGroup, TaskGroup
+from airflow.utils.types import NOTSET, ArgNotSet
 
 if TYPE_CHECKING:
     from pydantic import BaseModel
@@ -499,6 +500,8 @@ def _pydantic_model_dump(model_cls: type[BaseModel], var: Any) -> dict[str, Any]
                 return cls._encode(_pydantic_model_dump(DatasetPydantic, var), type_=DAT.DATA_SET)
             else:
                 return cls.default_serialization(strict, var)
+        elif isinstance(var, ArgNotSet):
+            return cls._encode(None, type_=DAT.ARG_NOT_SET)
         else:
             return cls.default_serialization(strict, var)
 
@@ -572,6 +575,8 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
                 return DagRunPydantic.parse_obj(var)
             elif type_ == DAT.DATA_SET:
                 return DatasetPydantic.parse_obj(var)
+        elif type_ == DAT.ARG_NOT_SET:
+            return NOTSET
         else:
             raise TypeError(f"Invalid type {type_!s} in deserialization.")
 
diff --git a/tests/models/test_param.py b/tests/models/test_param.py
index 4053cf6571716..b73cfea15f4ee 100644
--- a/tests/models/test_param.py
+++ b/tests/models/test_param.py
@@ -23,6 +23,7 @@
 from airflow.decorators import task
 from airflow.exceptions import ParamValidationError, RemovedInAirflow3Warning
 from airflow.models.param import Param, ParamsDict
+from airflow.serialization.serialized_objects import BaseSerialization
 from airflow.utils import timezone
 from airflow.utils.types import DagRunType
 from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom
@@ -41,14 +42,20 @@ def test_null_param(self):
         with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"):
             p.resolve()
         assert p.resolve(None) is None
+        assert p.dump()["value"] is None
+        assert not p.has_value
 
         p = Param(None)
         assert p.resolve() is None
         assert p.resolve(None) is None
+        assert p.dump()["value"] is None
+        assert not p.has_value
 
         p = Param(None, type="null")
         assert p.resolve() is None
         assert p.resolve(None) is None
+        assert p.dump()["value"] is None
+        assert not p.has_value
         with pytest.raises(ParamValidationError):
             p.resolve("test")
 
@@ -222,6 +229,30 @@ def test_dump(self):
         assert dump["description"] == "world"
         assert dump["schema"] == {"type": "string", "minLength": 2}
 
+    @pytest.mark.parametrize(
+        "param",
+        [
+            Param("my value", description="hello", schema={"type": "string"}),
+            Param("my value", description="hello"),
+            Param(None, description=None),
+            Param([True], type="array", items={"type": "boolean"}),
+            Param(),
+        ],
+    )
+    def test_param_serialization(self, param: Param):
+        """
+        Test to make sure that native Param objects can be correctly serialized
+        """
+
+        serializer = BaseSerialization()
+        serialized_param = serializer.serialize(param)
+        restored_param: Param = serializer.deserialize(serialized_param)
+
+        assert restored_param.value == param.value
+        assert isinstance(restored_param, Param)
+        assert restored_param.description == param.description
+        assert restored_param.schema == param.schema
+
 
 class TestParamsDict:
     def test_params_dict(self):
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 26474a8929882..5338579e01044 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -899,20 +899,21 @@ def __init__(self, path: str):
             Param("my value", description="hello"),
             Param(None, description=None),
             Param([True], type="array", items={"type": "boolean"}),
+            Param(),
         ],
     )
-    def test_full_param_roundtrip(self, param):
+    def test_full_param_roundtrip(self, param: Param):
         """
         Test to make sure that only native Param objects are being passed as dag or task params
         """
 
-        dag = DAG(dag_id="simple_dag", params={"my_param": param})
+        dag = DAG(dag_id="simple_dag", schedule=None, params={"my_param": param})
         serialized_json = SerializedDAG.to_json(dag)
         serialized = json.loads(serialized_json)
         SerializedDAG.validate_schema(serialized)
         dag = SerializedDAG.from_dict(serialized)
 
-        assert dag.params["my_param"] == param.value
+        assert dag.params.get_param("my_param").value == param.value
         observed_param = dag.params.get_param("my_param")
         assert isinstance(observed_param, Param)
         assert observed_param.description == param.description