Skip to content

Commit

Permalink
Refactor results and result records to improve circular import situation (#17031)
Browse files Browse the repository at this point in the history
  • Loading branch information
cicdw authored and kevingrismore committed Feb 7, 2025
1 parent 4fffa8d commit ec28f78
Show file tree
Hide file tree
Showing 11 changed files with 291 additions and 254 deletions.
235 changes: 235 additions & 0 deletions src/prefect/_result_records.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
from __future__ import annotations

import inspect
import uuid
from typing import (
TYPE_CHECKING,
Any,
Generic,
Optional,
TypeVar,
Union,
)
from uuid import UUID

from pydantic import (
BaseModel,
Field,
ValidationError,
model_validator,
)

import prefect
from prefect.exceptions import (
SerializationError,
)
from prefect.serializers import PickleSerializer, Serializer
from prefect.types import DateTime

if TYPE_CHECKING:
pass


# A result serializer may be given as a concrete Serializer instance or as a
# string — presumably a registered serializer name resolved elsewhere; confirm
# against the results module.
ResultSerializer = Union[Serializer, str]
# Types treated as "literal" values — NOTE(review): not consumed in this file;
# presumably stored inline rather than serialized. Verify against callers.
LITERAL_TYPES: set[type] = {type(None), bool, UUID}
# Type variable for the result value carried by a ResultRecord.
R = TypeVar("R")


class ResultRecordMetadata(BaseModel):
    """
    Metadata describing how, where, and until when a result record is stored.
    """

    # Optional for backwards compatibility with older persisted records.
    storage_key: Optional[str] = Field(default=None)
    expiration: Optional[DateTime] = Field(default=None)
    serializer: Serializer = Field(default_factory=PickleSerializer)
    prefect_version: str = Field(default=prefect.__version__)
    storage_block_id: Optional[uuid.UUID] = Field(default=None)

    def dump_bytes(self) -> bytes:
        """
        Serialize the metadata to bytes.

        Returns:
            bytes: the serialized metadata
        """
        # serialize_as_any preserves the concrete Serializer subclass fields.
        json_text = self.model_dump_json(serialize_as_any=True)
        return json_text.encode()

    @classmethod
    def load_bytes(cls, data: bytes) -> "ResultRecordMetadata":
        """
        Deserialize metadata from bytes.

        Args:
            data: the serialized metadata

        Returns:
            ResultRecordMetadata: the deserialized metadata
        """
        return cls.model_validate_json(data)

    def __eq__(self, other: Any) -> bool:
        # Field-wise equality; anything that is not ResultRecordMetadata
        # (or a subclass) compares unequal.
        if not isinstance(other, ResultRecordMetadata):
            return False
        compared_fields = (
            "storage_key",
            "expiration",
            "serializer",
            "prefect_version",
            "storage_block_id",
        )
        return all(
            getattr(self, name) == getattr(other, name) for name in compared_fields
        )


class ResultRecord(BaseModel, Generic[R]):
    """
    A record of a result: the result value itself plus the metadata
    (serializer, storage key, expiration, etc.) describing how it is persisted.
    """

    # Storage/serialization metadata for this record.
    metadata: ResultRecordMetadata
    # The result value; may hold raw bytes/str transiently during
    # (de)serialization — see `serialize` and `deserialize`.
    result: R

    @property
    def expiration(self) -> DateTime | None:
        # Convenience passthrough to the metadata's expiration.
        return self.metadata.expiration

    @property
    def serializer(self) -> Serializer:
        # Convenience passthrough to the metadata's serializer.
        return self.metadata.serializer

    def serialize_result(self) -> bytes:
        """
        Serialize only the result value using this record's serializer.

        Returns:
            bytes: the serialized result value

        Raises:
            SerializationError: if the configured serializer fails; the message
                includes recovery hints for the user.
        """
        try:
            data = self.serializer.dumps(self.result)
        except Exception as exc:
            extra_info = (
                'You can try a different serializer (e.g. result_serializer="json") '
                "or disabling persistence (persist_result=False) for this flow or task."
            )
            # check if this is a known issue with cloudpickle and pydantic
            # and add extra information to help the user recover

            if (
                isinstance(exc, TypeError)
                and isinstance(self.result, BaseModel)
                and str(exc).startswith("cannot pickle")
            ):
                try:
                    # IPython may not be installed; fall back to the generic
                    # hint if the import fails.
                    from IPython.core.getipython import get_ipython

                    if get_ipython() is not None:
                        extra_info = inspect.cleandoc(
                            """
                            This is a known issue in Pydantic that prevents
                            locally-defined (non-imported) models from being
                            serialized by cloudpickle in IPython/Jupyter
                            environments. Please see
                            https://github.com/pydantic/pydantic/issues/8232 for
                            more information. To fix the issue, either: (1) move
                            your Pydantic class definition to an importable
                            location, (2) use the JSON serializer for your flow
                            or task (`result_serializer="json"`), or (3)
                            disable result persistence for your flow or task
                            (`persist_result=False`).
                            """
                        ).replace("\n", " ")
                except ImportError:
                    pass
            raise SerializationError(
                f"Failed to serialize object of type {type(self.result).__name__!r} with "
                f"serializer {self.serializer.type!r}. {extra_info}"
            ) from exc

        return data

    @model_validator(mode="before")
    @classmethod
    def coerce_old_format(cls, value: dict[str, Any] | Any) -> dict[str, Any]:
        # Accept the legacy flat payload shape: rename "data" -> "result" and
        # fold top-level "expiration"/"serializer"/"prefect_version" keys into
        # the nested "metadata" dict expected by the current model.
        if isinstance(value, dict):
            if "data" in value:
                value["result"] = value.pop("data")
            if "metadata" not in value:
                value["metadata"] = {}
            if "expiration" in value:
                value["metadata"]["expiration"] = value.pop("expiration")
            if "serializer" in value:
                value["metadata"]["serializer"] = value.pop("serializer")
            if "prefect_version" in value:
                value["metadata"]["prefect_version"] = value.pop("prefect_version")
        return value

    def serialize_metadata(self) -> bytes:
        """Serialize only the metadata portion of this record to bytes."""
        return self.metadata.dump_bytes()

    def serialize(
        self,
    ) -> bytes:
        """
        Serialize the record to bytes.

        The result value is serialized with the record's serializer first,
        then the whole record (metadata + serialized result) is JSON-encoded.

        Returns:
            bytes: the serialized record
        """
        return (
            self.model_copy(update={"result": self.serialize_result()})
            .model_dump_json(serialize_as_any=True)
            .encode()
        )

    @classmethod
    def deserialize(
        cls, data: bytes, backup_serializer: Serializer | None = None
    ) -> "ResultRecord[R]":
        """
        Deserialize a record from bytes.

        Args:
            data: the serialized record
            backup_serializer: The serializer to use to deserialize the result record. Only
                necessary if the provided data does not specify a serializer.

        Returns:
            ResultRecord: the deserialized record

        Raises:
            pydantic.ValidationError: if `data` is not a valid record and no
                `backup_serializer` was provided.
        """
        try:
            instance = cls.model_validate_json(data)
        except ValidationError:
            # Not a full record payload: treat `data` as a bare serialized
            # result and wrap it in fresh metadata, if a fallback was given.
            if backup_serializer is None:
                raise
            else:
                result = backup_serializer.loads(data)
                return cls(
                    metadata=ResultRecordMetadata(serializer=backup_serializer),
                    result=result,
                )
        # The embedded result is still serialized (bytes, or str after the
        # JSON round-trip); decode it back into the actual value.
        if isinstance(instance.result, bytes):
            instance.result = instance.serializer.loads(instance.result)
        elif isinstance(instance.result, str):
            instance.result = instance.serializer.loads(instance.result.encode())
        return instance

    @classmethod
    def deserialize_from_result_and_metadata(
        cls, result: bytes, metadata: bytes
    ) -> "ResultRecord[R]":
        """
        Deserialize a record from separate result and metadata bytes.

        Args:
            result: the result
            metadata: the serialized metadata

        Returns:
            ResultRecord: the deserialized record
        """
        result_record_metadata = ResultRecordMetadata.load_bytes(metadata)
        return cls(
            metadata=result_record_metadata,
            result=result_record_metadata.serializer.loads(result),
        )

    def __eq__(self, other: Any | "ResultRecord[Any]") -> bool:
        # Records are equal when both metadata and result value are equal;
        # anything that is not a ResultRecord compares unequal.
        if not isinstance(other, ResultRecord):
            return False
        return self.metadata == other.metadata and self.result == other.result
2 changes: 1 addition & 1 deletion src/prefect/client/schemas/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from prefect.utilities.pydantic import get_class_fields_only

if TYPE_CHECKING:
from prefect.results import ResultRecordMetadata
from prefect._result_records import ResultRecordMetadata

R = TypeVar("R")

Expand Down
2 changes: 1 addition & 1 deletion src/prefect/client/schemas/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@
from prefect.utilities.pydantic import handle_secret_render

if TYPE_CHECKING:
from prefect._result_records import ResultRecordMetadata
from prefect.client.schemas.actions import StateCreate
from prefect.results import ResultRecordMetadata


R = TypeVar("R", default=Any)
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def hydrated_context(
# We need to rebuild the models because we might be hydrating in a remote
# environment where the models are not available.
# TODO: Remove this once we have fixed our circular imports and we don't need to rebuild models any more.
from prefect._result_records import ResultRecordMetadata
from prefect.flows import Flow
from prefect.results import ResultRecordMetadata
from prefect.tasks import Task

_types: dict[str, Any] = dict(
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/deployments/flow_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import pendulum

import prefect
from prefect._result_records import ResultRecordMetadata
from prefect.client.schemas import FlowRun
from prefect.client.utilities import inject_client
from prefect.context import FlowRunContext, TaskRunContext
from prefect.logging import get_logger
from prefect.results import ResultRecordMetadata
from prefect.states import Pending, Scheduled
from prefect.tasks import Task
from prefect.telemetry.run_telemetry import (
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from prefect.tasks import task, Task
from prefect.context import tags
from prefect.utilities.annotations import unmapped, allow_failure
from prefect.results import ResultRecordMetadata
from prefect._result_records import ResultRecordMetadata
from prefect.flow_runs import pause_flow_run, resume_flow_run, suspend_flow_run
from prefect.client.orchestration import get_client
from prefect.client.cloud import get_cloud_client
Expand Down
Loading

0 comments on commit ec28f78

Please sign in to comment.