Skip to content

Commit

Permalink
Fix handling of default value and serialization of Param class (#33141)
Browse files Browse the repository at this point in the history
  • Loading branch information
jscheffl authored Aug 16, 2023
1 parent 6547b81 commit 489ca14
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 5 deletions.
9 changes: 7 additions & 2 deletions airflow/models/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,18 @@ def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:

def dump(self) -> dict:
    """Dump the Param as a dictionary.

    The result maps ``CLASS_IDENTIFIER`` to the fully qualified class path
    of this instance and then adds every instance attribute found in
    ``__dict__``. A value that was never set (``NOTSET``) is reported as
    ``None``.
    """
    out_dict: dict[str, str | None] = {
        self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"
    }
    out_dict.update(self.__dict__)
    # Ensure that NOTSET is translated to None
    if self.value is NOTSET:
        out_dict["value"] = None
    return out_dict

@property
def has_value(self) -> bool:
    """Return whether this Param carries a usable value.

    Both ``NOTSET`` (nothing was assigned) and ``None`` count as "no value".
    """
    return self.value is not NOTSET and self.value is not None

def serialize(self) -> dict:
    """Return the value, description and schema of this Param as a plain dict."""
    exported = ("value", "description", "schema")
    return {attr: getattr(self, attr) for attr in exported}
Expand Down
1 change: 1 addition & 0 deletions airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,4 @@ class DagAttributeTypes(str, Enum):
TASK_INSTANCE = "task_instance"
DAG_RUN = "dag_run"
DATA_SET = "data_set"
ARG_NOT_SET = "arg_not_set"
5 changes: 5 additions & 0 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from airflow.utils.module_loading import import_string, qualname
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import MappedTaskGroup, TaskGroup
from airflow.utils.types import NOTSET, ArgNotSet

if TYPE_CHECKING:
from pydantic import BaseModel
Expand Down Expand Up @@ -499,6 +500,8 @@ def _pydantic_model_dump(model_cls: type[BaseModel], var: Any) -> dict[str, Any]
return cls._encode(_pydantic_model_dump(DatasetPydantic, var), type_=DAT.DATA_SET)
else:
return cls.default_serialization(strict, var)
elif isinstance(var, ArgNotSet):
return cls._encode(None, type_=DAT.ARG_NOT_SET)
else:
return cls.default_serialization(strict, var)

Expand Down Expand Up @@ -572,6 +575,8 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
return DagRunPydantic.parse_obj(var)
elif type_ == DAT.DATA_SET:
return DatasetPydantic.parse_obj(var)
elif type_ == DAT.ARG_NOT_SET:
return NOTSET
else:
raise TypeError(f"Invalid type {type_!s} in deserialization.")

Expand Down
31 changes: 31 additions & 0 deletions tests/models/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from airflow.decorators import task
from airflow.exceptions import ParamValidationError, RemovedInAirflow3Warning
from airflow.models.param import Param, ParamsDict
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.utils import timezone
from airflow.utils.types import DagRunType
from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom
Expand All @@ -41,14 +42,20 @@ def test_null_param(self):
with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"):
p.resolve()
assert p.resolve(None) is None
assert p.dump()["value"] is None
assert not p.has_value

p = Param(None)
assert p.resolve() is None
assert p.resolve(None) is None
assert p.dump()["value"] is None
assert not p.has_value

p = Param(None, type="null")
assert p.resolve() is None
assert p.resolve(None) is None
assert p.dump()["value"] is None
assert not p.has_value
with pytest.raises(ParamValidationError):
p.resolve("test")

Expand Down Expand Up @@ -222,6 +229,30 @@ def test_dump(self):
assert dump["description"] == "world"
assert dump["schema"] == {"type": "string", "minLength": 2}

@pytest.mark.parametrize(
    "param",
    [
        Param("my value", description="hello", schema={"type": "string"}),
        Param("my value", description="hello"),
        Param(None, description=None),
        Param([True], type="array", items={"type": "boolean"}),
        Param(),
    ],
)
def test_param_serialization(self, param: Param):
    """Round-trip a native Param through BaseSerialization and compare its fields."""
    codec = BaseSerialization()
    # Serialize and immediately deserialize; the result must mirror the input.
    restored_param: Param = codec.deserialize(codec.serialize(param))

    assert restored_param.value == param.value
    assert isinstance(restored_param, Param)
    assert restored_param.description == param.description
    assert restored_param.schema == param.schema


class TestParamsDict:
def test_params_dict(self):
Expand Down
7 changes: 4 additions & 3 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,20 +899,21 @@ def __init__(self, path: str):
Param("my value", description="hello"),
Param(None, description=None),
Param([True], type="array", items={"type": "boolean"}),
Param(),
],
)
def test_full_param_roundtrip(self, param):
def test_full_param_roundtrip(self, param: Param):
"""
Test to make sure that only native Param objects are being passed as dag or task params
"""

dag = DAG(dag_id="simple_dag", params={"my_param": param})
dag = DAG(dag_id="simple_dag", schedule=None, params={"my_param": param})
serialized_json = SerializedDAG.to_json(dag)
serialized = json.loads(serialized_json)
SerializedDAG.validate_schema(serialized)
dag = SerializedDAG.from_dict(serialized)

assert dag.params["my_param"] == param.value
assert dag.params.get_param("my_param").value == param.value
observed_param = dag.params.get_param("my_param")
assert isinstance(observed_param, Param)
assert observed_param.description == param.description
Expand Down

0 comments on commit 489ca14

Please sign in to comment.