Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create generic SerializeAsOptional type for Pydantic #564

Merged
merged 31 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
813854a
Create annotated list type that serializes to `None` if empty
disrupted Dec 12, 2024
6e7448f
Expand test
disrupted Dec 12, 2024
903434e
Make it more generic
disrupted Dec 12, 2024
111948e
Improve schema
disrupted Dec 12, 2024
372c6be
Move to Pydantic utils
disrupted Dec 12, 2024
c5c2876
Refactor other optional collection types
disrupted Dec 12, 2024
f716bcb
Rename schema class
disrupted Dec 12, 2024
306019c
Apply to StreamsBootstrapValues
disrupted Dec 16, 2024
e74bcf4
Apply to StreamsBootstrapValues
disrupted Dec 16, 2024
7503e4b
Fix
disrupted Dec 16, 2024
ebf19a2
Skip serialization step for cleaner instantiation
disrupted Dec 16, 2024
2091375
Add failing test
disrupted Dec 16, 2024
7a1f4b0
Try refactor to include serializer function
disrupted Dec 16, 2024
5917f8c
Cosmetic
disrupted Dec 16, 2024
32de61b
Implement and explain workaround in test
disrupted Dec 16, 2024
bf2f6b6
Apply workaround
disrupted Dec 16, 2024
3ad6c9c
Link to upstream issue with potential solution for `exclude_none`
disrupted Dec 16, 2024
0c3ccb8
Validate `None` to default
disrupted Dec 16, 2024
e271eb2
Expand test
disrupted Dec 16, 2024
afd6ac2
Serialize StreamsBootstrapValues correctly
disrupted Dec 16, 2024
549697c
Add test for exclude_by_value
disrupted Dec 16, 2024
e55dfc9
Inherit from SerializeAsOptionalModel
disrupted Dec 16, 2024
0ba91ca
Extend snapshot test with affinity
disrupted Dec 16, 2024
ec8841f
Update snapshot
disrupted Dec 16, 2024
8bfddd7
Merge branch 'main' into refactor/pydantic-optional-list
disrupted Dec 17, 2024
1043311
Merge remote-tracking branch 'origin/main' into refactor/pydantic-opt…
disrupted Dec 17, 2024
92c9698
Apply SerializeAsOptional to streams-bootstrap v2
disrupted Dec 17, 2024
3020168
Fix allow optional resources requests and limits (#570)
disrupted Dec 17, 2024
b14738f
Bump version 8.3.1 → 8.3.2
bakdata-bot Dec 17, 2024
d37cdfb
Merge branch 'main' into refactor/pydantic-optional-list
disrupted Dec 17, 2024
286be69
Fix import
disrupted Dec 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 53 additions & 53 deletions docs/docs/schema/defaults.json

Large diffs are not rendered by default.

80 changes: 40 additions & 40 deletions docs/docs/schema/pipeline.json

Large diffs are not rendered by default.

58 changes: 31 additions & 27 deletions kpops/components/common/kubernetes_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from pydantic import Field, model_validator

from kpops.utils.docstring import describe_attr
from kpops.utils.pydantic import CamelCaseConfigModel, DescConfigModel
from kpops.utils.pydantic import (
CamelCaseConfigModel,
DescConfigModel,
SerializeAsOptional,
)

if TYPE_CHECKING:
try:
Expand Down Expand Up @@ -98,11 +102,11 @@ class NodeSelectorTerm(DescConfigModel, CamelCaseConfigModel):
:param match_fields: A list of node selector requirements by node's fields.
"""

match_expressions: list[NodeSelectorRequirement] | None = Field(
default=None, description=describe_attr("match_expressions", __doc__)
match_expressions: SerializeAsOptional[list[NodeSelectorRequirement]] = Field(
default=[], description=describe_attr("match_expressions", __doc__)
)
match_fields: list[NodeSelectorRequirement] | None = Field(
default=None, description=describe_attr("match_fields", __doc__)
match_fields: SerializeAsOptional[list[NodeSelectorRequirement]] = Field(
default=[], description=describe_attr("match_fields", __doc__)
)


Expand Down Expand Up @@ -143,10 +147,10 @@ class NodeAffinity(DescConfigModel, CamelCaseConfigModel):
"required_during_scheduling_ignored_during_execution", __doc__
),
)
preferred_during_scheduling_ignored_during_execution: (
list[PreferredSchedulingTerm] | None
) = Field(
default=None,
preferred_during_scheduling_ignored_during_execution: SerializeAsOptional[
list[PreferredSchedulingTerm]
] = Field(
default=[],
description=describe_attr(
"preferred_during_scheduling_ignored_during_execution", __doc__
),
Expand Down Expand Up @@ -197,12 +201,12 @@ class LabelSelector(DescConfigModel, CamelCaseConfigModel):
:param match_expressions: matchExpressions is a list of label selector requirements. The requirements are ANDed.
"""

match_labels: dict[str, str] | None = Field(
default=None,
match_labels: SerializeAsOptional[dict[str, str]] = Field(
default={},
description=describe_attr("match_labels", __doc__),
)
match_expressions: list[LabelSelectorRequirement] | None = Field(
default=None,
match_expressions: SerializeAsOptional[list[LabelSelectorRequirement]] = Field(
default=[],
description=describe_attr("match_expressions", __doc__),
)

Expand All @@ -222,19 +226,19 @@ class PodAffinityTerm(DescConfigModel, CamelCaseConfigModel):
default=None,
description=describe_attr("label_selector", __doc__),
)
match_label_keys: list[str] | None = Field(
default=None,
match_label_keys: SerializeAsOptional[list[str]] = Field(
default=[],
description=describe_attr("match_label_keys", __doc__),
)
mismatch_label_keys: list[str] | None = Field(
default=None,
mismatch_label_keys: SerializeAsOptional[list[str]] = Field(
default=[],
description=describe_attr("mismatch_label_keys", __doc__),
)
topology_key: str = Field(
description=describe_attr("topology_key", __doc__),
)
namespaces: list[str] | None = Field(
default=None,
namespaces: SerializeAsOptional[list[str]] = Field(
default=[],
description=describe_attr("namespaces", __doc__),
)
namespace_selector: LabelSelector | None = Field(
Expand Down Expand Up @@ -265,18 +269,18 @@ class PodAffinity(DescConfigModel, CamelCaseConfigModel):
:param preferred_during_scheduling_ignored_during_execution: The scheduler will prefer to schedule pods to nodes that satisfy the affinity expressions specified by this field, but it may choose a node that violates one or more of the expressions. The node that is most preferred is the one with the greatest sum of weights, i.e. for each node that meets all of the scheduling requirements (resource request, requiredDuringScheduling affinity expressions, etc.), compute a sum by iterating through the elements of this field and adding weight to the sum if the node has pods which matches the corresponding podAffinityTerm; the node(s) with the highest sum are the most preferred.
"""

required_during_scheduling_ignored_during_execution: (
list[PodAffinityTerm] | None
) = Field(
default=None,
required_during_scheduling_ignored_during_execution: SerializeAsOptional[
list[PodAffinityTerm]
] = Field(
default=[],
description=describe_attr(
"required_during_scheduling_ignored_during_execution", __doc__
),
)
preferred_during_scheduling_ignored_during_execution: (
list[WeightedPodAffinityTerm] | None
) = Field(
default=None,
preferred_during_scheduling_ignored_during_execution: SerializeAsOptional[
list[WeightedPodAffinityTerm]
] = Field(
default=[],
description=describe_attr(
"preferred_during_scheduling_ignored_during_execution", __doc__
),
Expand Down
53 changes: 27 additions & 26 deletions kpops/components/streams_bootstrap/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ImagePullPolicy,
ProtocolSchema,
Resources,
SerializeAsOptional,
ServiceType,
Toleration,
)
Expand Down Expand Up @@ -132,8 +133,8 @@ class StreamsBootstrapValues(HelmAppValues):
description=describe_attr("image_pull_policy", __doc__),
)

image_pull_secrets: list[dict[str, str]] | None = Field(
default=None,
image_pull_secrets: SerializeAsOptional[list[dict[str, str]]] = Field(
default=[],
description=describe_attr("image_pull_secret", __doc__),
)

Expand All @@ -146,8 +147,8 @@ class StreamsBootstrapValues(HelmAppValues):
description=describe_attr("resources", __doc__),
)

ports: list[PortConfig] | None = Field(
default=None,
ports: SerializeAsOptional[list[PortConfig]] = Field(
default=[],
description=describe_attr("ports", __doc__),
)

Expand All @@ -161,33 +162,33 @@ class StreamsBootstrapValues(HelmAppValues):
description=describe_attr("configuration_env_prefix", __doc__),
)

command_line: dict[str, str | bool | int] | None = Field(
default=None,
command_line: SerializeAsOptional[dict[str, str | bool | int]] = Field(
default={},
description=describe_attr("command_line", __doc__),
)

env: dict[str, str] | None = Field(
default=None,
env: SerializeAsOptional[dict[str, str]] = Field(
default={},
description=describe_attr("env", __doc__),
)

secrets: dict[str, str] | None = Field(
default=None,
secrets: SerializeAsOptional[dict[str, str]] = Field(
default={},
description=describe_attr("secrets", __doc__),
)

secret_refs: dict[str, Any] | None = Field(
default=None,
secret_refs: SerializeAsOptional[dict[str, Any]] = Field(
default={},
description=describe_attr("secret_refs", __doc__),
)

secret_files_refs: list[str] | None = Field(
default=None,
secret_files_refs: SerializeAsOptional[list[str]] = Field(
default=[],
description=describe_attr("secret_files_refs", __doc__),
)

files: dict[str, Any] | None = Field(
default=None,
files: SerializeAsOptional[dict[str, Any]] = Field(
default={},
description=describe_attr("files", __doc__),
)

Expand All @@ -196,23 +197,23 @@ class StreamsBootstrapValues(HelmAppValues):
description=describe_attr("java_options", __doc__),
)

pod_annotations: dict[str, str] | None = Field(
default=None,
pod_annotations: SerializeAsOptional[dict[str, str]] = Field(
default={},
description=describe_attr("pod_annotations", __doc__),
)

pod_labels: dict[str, str] | None = Field(
default=None,
pod_labels: SerializeAsOptional[dict[str, str]] = Field(
default={},
description=describe_attr("pod_labels", __doc__),
)

liveness_probe: dict[str, Any] | None = Field(
default=None,
liveness_probe: SerializeAsOptional[dict[str, Any]] = Field(
default={},
description=describe_attr("liveness_probe", __doc__),
)

readiness_probe: dict[str, Any] | None = Field(
default=None,
readiness_probe: SerializeAsOptional[dict[str, Any]] = Field(
default={},
description=describe_attr("readiness_probe", __doc__),
)

Expand All @@ -221,8 +222,8 @@ class StreamsBootstrapValues(HelmAppValues):
description=describe_attr("affinity", __doc__),
)

tolerations: list[Toleration] | None = Field(
default=None,
tolerations: SerializeAsOptional[list[Toleration]] = Field(
default=[],
description=describe_attr("tolerations", __doc__),
)

Expand Down
9 changes: 6 additions & 3 deletions kpops/components/streams_bootstrap/producer/producer_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,12 @@ def is_cron_job(self) -> bool:
@computed_field
@cached_property
def _cleaner(self) -> ProducerAppCleaner:
return ProducerAppCleaner(
**self.model_dump(by_alias=True, exclude={"_cleaner", "from_", "to"})
)
kwargs = {
name: getattr(self, name)
for name in self.model_fields_set
if name not in {"_cleaner", "from_", "to"}
}
return ProducerAppCleaner.model_validate(kwargs)

@override
def apply_to_outputs(self, name: str, topic: TopicConfig) -> None:
Expand Down
17 changes: 9 additions & 8 deletions kpops/components/streams_bootstrap/streams/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from kpops.components.common.kubernetes_model import (
ImagePullPolicy,
Resources,
SerializeAsOptional,
)
from kpops.components.common.topic import KafkaTopic, KafkaTopicStr
from kpops.components.streams_bootstrap.model import (
Expand Down Expand Up @@ -191,16 +192,16 @@ class StreamsAppAutoScaling(CamelCaseConfigModel, DescConfigModel):
title="Idle replica count",
description=describe_attr("idle_replicas", __doc__),
)
internal_topics: list[str] | None = Field(
default=None,
internal_topics: SerializeAsOptional[list[str]] = Field(
default=[],
description=describe_attr("internal_topics", __doc__),
)
topics: list[str] | None = Field(
default=None,
topics: SerializeAsOptional[list[str]] = Field(
default=[],
description=describe_attr("topics", __doc__),
)
additional_triggers: list[str] | None = Field(
default=None,
additional_triggers: SerializeAsOptional[list[str]] = Field(
default=[],
description=describe_attr("additional_triggers", __doc__),
)
model_config = ConfigDict(extra="allow")
Expand Down Expand Up @@ -289,8 +290,8 @@ class JMXConfig(CamelCaseConfigModel, DescConfigModel):
description=describe_attr("port", __doc__),
)

metric_rules: list[str] | None = Field(
default=None,
metric_rules: SerializeAsOptional[list[str]] = Field(
default=[],
description=describe_attr("metric_rules", __doc__),
)

Expand Down
9 changes: 6 additions & 3 deletions kpops/components/streams_bootstrap/streams/streams_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,12 @@ class StreamsApp(StreamsBootstrap):
@computed_field
@cached_property
def _cleaner(self) -> StreamsAppCleaner:
return StreamsAppCleaner(
**self.model_dump(by_alias=True, exclude={"_cleaner", "from_", "to"})
)
kwargs = {
name: getattr(self, name)
for name in self.model_fields_set
if name not in {"_cleaner", "from_", "to"}
}
return StreamsAppCleaner.model_validate(kwargs)

@property
@override
Expand Down
44 changes: 42 additions & 2 deletions kpops/utils/pydantic.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import json
import logging
from pathlib import Path
from typing import Any
from typing import Annotated, Any

import humps
from pydantic import BaseModel, ConfigDict, Field
from pydantic import (
BaseModel,
ConfigDict,
Field,
GetCoreSchemaHandler,
SerializationInfo,
SerializerFunctionWrapHandler,
WrapSerializer,
)
from pydantic.fields import FieldInfo
from pydantic_core import core_schema
from pydantic_settings import PydanticBaseSettingsSource
from typing_extensions import TypeVar, override

Expand Down Expand Up @@ -224,3 +233,34 @@ def __call__(self) -> dict[str, Any]:
if field_value is not None:
d[field_key] = field_value
return d


# Unconstrained type parameter: stands for the wrapped field type in
# `serialize_to_optional` and in the `SerializeAsOptional` annotated alias.
_T = TypeVar("_T")


def serialize_to_optional(
    value: _T,
    default_serialize_handler: SerializerFunctionWrapHandler,
    info: SerializationInfo,
) -> _T | None:
    """Serialize ``value`` normally, but emit ``None`` for a falsy result.

    :param value: The field value to serialize.
    :param default_serialize_handler: Callable that performs the regular
        serialization of ``value``; it is invoked exactly once.
    :param info: Serialization context; accepted as part of the signature
        but not consulted here.
    :returns: The handler's output, or ``None`` when that output is falsy
        (for example an empty list or an empty dict).
    """
    serialized = default_serialize_handler(value)
    # Truthiness is checked on the *serialized* form, not on the raw value.
    if not serialized:
        return None
    return serialized


class WrapNullableSchema:
    """Annotation marker that turns the annotated type's core schema nullable.

    An instance of this class is placed in ``Annotated[...]`` metadata (as
    done by ``SerializeAsOptional``) so that Pydantic calls
    ``__get_pydantic_core_schema__`` while building the field schema.
    """

    def __get_pydantic_core_schema__(
        self,
        source: type[Any],
        handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        """Build the schema for ``source`` and wrap it in a nullable schema.

        :param source: The type the annotation is attached to.
        :param handler: Pydantic's handler that generates the schema for
            ``source``.
        :returns: A ``nullable`` core schema whose inner schema is the one
            produced by ``handler``.
        """
        inner_schema = handler(source)
        return core_schema.NullableSchema(
            type="nullable",
            schema=inner_schema,
        )


# Generic annotated alias: `SerializeAsOptional[X]` annotates a field as `X`
# and attaches the following metadata:
#   * `WrapSerializer(serialize_to_optional)` - serialization goes through
#     `serialize_to_optional`, which emits `None` when the serialized value is
#     falsy (e.g. an empty list or dict).
#   * `WrapNullableSchema()` - the generated core schema is wrapped in a
#     nullable schema.
#   * a plain string describing the alias.
SerializeAsOptional = Annotated[
    _T,
    WrapSerializer(serialize_to_optional),
    WrapNullableSchema(),
    "Optional that is serialized to None if falsy",
]
Loading
Loading