From ec28f7820c899d8fb70960891883c7176e46b130 Mon Sep 17 00:00:00 2001 From: Chris White Date: Fri, 7 Feb 2025 10:13:15 -0800 Subject: [PATCH] Refactor results and result records to improve circular import situation (#17031) --- src/prefect/_result_records.py | 235 +++++++++++++++ src/prefect/client/schemas/actions.py | 2 +- src/prefect/client/schemas/objects.py | 2 +- src/prefect/context.py | 2 +- src/prefect/deployments/flow_runs.py | 2 +- src/prefect/main.py | 2 +- src/prefect/results.py | 268 ++---------------- src/prefect/states.py | 11 +- tests/results/test_result_record.py | 7 +- tests/results/test_result_store.py | 12 +- .../server/orchestration/test_core_policy.py | 2 +- 11 files changed, 291 insertions(+), 254 deletions(-) create mode 100644 src/prefect/_result_records.py diff --git a/src/prefect/_result_records.py b/src/prefect/_result_records.py new file mode 100644 index 0000000000000..dd538e88b2a6d --- /dev/null +++ b/src/prefect/_result_records.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +import inspect +import uuid +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Optional, + TypeVar, + Union, +) +from uuid import UUID + +from pydantic import ( + BaseModel, + Field, + ValidationError, + model_validator, +) + +import prefect +from prefect.exceptions import ( + SerializationError, +) +from prefect.serializers import PickleSerializer, Serializer +from prefect.types import DateTime + +if TYPE_CHECKING: + pass + + +ResultSerializer = Union[Serializer, str] +LITERAL_TYPES: set[type] = {type(None), bool, UUID} +R = TypeVar("R") + + +class ResultRecordMetadata(BaseModel): + """ + Metadata for a result record. + """ + + storage_key: Optional[str] = Field( + default=None + ) # optional for backwards compatibility + expiration: Optional[DateTime] = Field(default=None) + serializer: Serializer = Field(default_factory=PickleSerializer) + prefect_version: str = Field(default=prefect.__version__) + storage_block_id: Optional[uuid.UUID] = Field(default=None) + + def dump_bytes(self) -> bytes: + """ + Serialize the metadata to bytes. + + Returns: + bytes: the serialized metadata + """ + return self.model_dump_json(serialize_as_any=True).encode() + + @classmethod + def load_bytes(cls, data: bytes) -> "ResultRecordMetadata": + """ + Deserialize metadata from bytes. + + Args: + data: the serialized metadata + + Returns: + ResultRecordMetadata: the deserialized metadata + """ + return cls.model_validate_json(data) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, ResultRecordMetadata): + return False + return ( + self.storage_key == other.storage_key + and self.expiration == other.expiration + and self.serializer == other.serializer + and self.prefect_version == other.prefect_version + and self.storage_block_id == other.storage_block_id + ) + + +class ResultRecord(BaseModel, Generic[R]): + """ + A record of a result. + """ + + metadata: ResultRecordMetadata + result: R + + @property + def expiration(self) -> DateTime | None: + return self.metadata.expiration + + @property + def serializer(self) -> Serializer: + return self.metadata.serializer + + def serialize_result(self) -> bytes: + try: + data = self.serializer.dumps(self.result) + except Exception as exc: + extra_info = ( + 'You can try a different serializer (e.g. result_serializer="json") ' + "or disabling persistence (persist_result=False) for this flow or task." + ) + # check if this is a known issue with cloudpickle and pydantic + # and add extra information to help the user recover + + if ( + isinstance(exc, TypeError) + and isinstance(self.result, BaseModel) + and str(exc).startswith("cannot pickle") + ): + try: + from IPython.core.getipython import get_ipython + + if get_ipython() is not None: + extra_info = inspect.cleandoc( + """ + This is a known issue in Pydantic that prevents + locally-defined (non-imported) models from being + serialized by cloudpickle in IPython/Jupyter + environments. Please see + https://github.com/pydantic/pydantic/issues/8232 for + more information. To fix the issue, either: (1) move + your Pydantic class definition to an importable + location, (2) use the JSON serializer for your flow + or task (`result_serializer="json"`), or (3) + disable result persistence for your flow or task + (`persist_result=False`). + """ + ).replace("\n", " ") + except ImportError: + pass + raise SerializationError( + f"Failed to serialize object of type {type(self.result).__name__!r} with " + f"serializer {self.serializer.type!r}. {extra_info}" + ) from exc + + return data + + @model_validator(mode="before") + @classmethod + def coerce_old_format(cls, value: dict[str, Any] | Any) -> dict[str, Any]: + if isinstance(value, dict): + if "data" in value: + value["result"] = value.pop("data") + if "metadata" not in value: + value["metadata"] = {} + if "expiration" in value: + value["metadata"]["expiration"] = value.pop("expiration") + if "serializer" in value: + value["metadata"]["serializer"] = value.pop("serializer") + if "prefect_version" in value: + value["metadata"]["prefect_version"] = value.pop("prefect_version") + return value + + def serialize_metadata(self) -> bytes: + return self.metadata.dump_bytes() + + def serialize( + self, + ) -> bytes: + """ + Serialize the record to bytes. + + Returns: + bytes: the serialized record + + """ + return ( + self.model_copy(update={"result": self.serialize_result()}) + .model_dump_json(serialize_as_any=True) + .encode() + ) + + @classmethod + def deserialize( + cls, data: bytes, backup_serializer: Serializer | None = None + ) -> "ResultRecord[R]": + """ + Deserialize a record from bytes. + + Args: + data: the serialized record + backup_serializer: The serializer to use to deserialize the result record. Only + necessary if the provided data does not specify a serializer. + + Returns: + ResultRecord: the deserialized record + """ + try: + instance = cls.model_validate_json(data) + except ValidationError: + if backup_serializer is None: + raise + else: + result = backup_serializer.loads(data) + return cls( + metadata=ResultRecordMetadata(serializer=backup_serializer), + result=result, + ) + if isinstance(instance.result, bytes): + instance.result = instance.serializer.loads(instance.result) + elif isinstance(instance.result, str): + instance.result = instance.serializer.loads(instance.result.encode()) + return instance + + @classmethod + def deserialize_from_result_and_metadata( + cls, result: bytes, metadata: bytes + ) -> "ResultRecord[R]": + """ + Deserialize a record from separate result and metadata bytes. + + Args: + result: the result + metadata: the serialized metadata + + Returns: + ResultRecord: the deserialized record + """ + result_record_metadata = ResultRecordMetadata.load_bytes(metadata) + return cls( + metadata=result_record_metadata, + result=result_record_metadata.serializer.loads(result), + ) + + def __eq__(self, other: Any | "ResultRecord[Any]") -> bool: + if not isinstance(other, ResultRecord): + return False + return self.metadata == other.metadata and self.result == other.result diff --git a/src/prefect/client/schemas/actions.py b/src/prefect/client/schemas/actions.py index a5219a535b71b..771718b9beb38 100644 --- a/src/prefect/client/schemas/actions.py +++ b/src/prefect/client/schemas/actions.py @@ -46,7 +46,7 @@ from prefect.utilities.pydantic import get_class_fields_only if TYPE_CHECKING: - from prefect.results import ResultRecordMetadata + from prefect._result_records import ResultRecordMetadata R = TypeVar("R") diff --git a/src/prefect/client/schemas/objects.py b/src/prefect/client/schemas/objects.py index 2ba920865598b..3ab9868981cc9 100644 --- a/src/prefect/client/schemas/objects.py +++ b/src/prefect/client/schemas/objects.py @@ -65,8 +65,8 @@ from prefect.utilities.pydantic import handle_secret_render if TYPE_CHECKING: + from prefect._result_records import ResultRecordMetadata from prefect.client.schemas.actions import StateCreate - from prefect.results import ResultRecordMetadata R = TypeVar("R", default=Any) diff --git a/src/prefect/context.py b/src/prefect/context.py index 0a24544b46dee..8d1612c015c61 100644 --- a/src/prefect/context.py +++ b/src/prefect/context.py @@ -73,8 +73,8 @@ def hydrated_context( # We need to rebuild the models because we might be hydrating in a remote # environment where the models are not available. # TODO: Remove this once we have fixed our circular imports and we don't need to rebuild models any more. + from prefect._result_records import ResultRecordMetadata from prefect.flows import Flow - from prefect.results import ResultRecordMetadata from prefect.tasks import Task _types: dict[str, Any] = dict( diff --git a/src/prefect/deployments/flow_runs.py b/src/prefect/deployments/flow_runs.py index 169539df574f8..1927933ea040c 100644 --- a/src/prefect/deployments/flow_runs.py +++ b/src/prefect/deployments/flow_runs.py @@ -6,11 +6,11 @@ import pendulum import prefect +from prefect._result_records import ResultRecordMetadata from prefect.client.schemas import FlowRun from prefect.client.utilities import inject_client from prefect.context import FlowRunContext, TaskRunContext from prefect.logging import get_logger -from prefect.results import ResultRecordMetadata from prefect.states import Pending, Scheduled from prefect.tasks import Task from prefect.telemetry.run_telemetry import ( diff --git a/src/prefect/main.py b/src/prefect/main.py index 1e9e5421998fc..1fe7a64a9ebd9 100644 --- a/src/prefect/main.py +++ b/src/prefect/main.py @@ -8,7 +8,7 @@ from prefect.tasks import task, Task from prefect.context import tags from prefect.utilities.annotations import unmapped, allow_failure -from prefect.results import ResultRecordMetadata +from prefect._result_records import ResultRecordMetadata from prefect.flow_runs import pause_flow_run, resume_flow_run, suspend_flow_run from prefect.client.orchestration import get_client from prefect.client.cloud import get_cloud_client diff --git a/src/prefect/results.py b/src/prefect/results.py index 52f14c50d2438..5891dafe2f1dd 100644 --- a/src/prefect/results.py +++ b/src/prefect/results.py @@ -14,7 +14,6 @@ Any, Callable, ClassVar, - Generic, Optional, TypeVar, Union, @@ -28,22 +27,16 @@ Discriminator, Field, Tag, - ValidationError, - model_validator, ) from typing_extensions import ParamSpec, Self import prefect -from prefect._experimental.lineage import ( - emit_result_read_event, - emit_result_write_event, -) from prefect._internal.compatibility.async_dispatch import async_dispatch +from prefect._result_records import R, ResultRecord, ResultRecordMetadata from prefect.blocks.core import Block from prefect.exceptions import ( ConfigurationError, MissingContextError, - SerializationError, ) from prefect.filesystems import ( LocalFileSystem, @@ -52,7 +45,7 @@ ) from prefect.locking.protocol import LockManager from prefect.logging import get_logger -from prefect.serializers import PickleSerializer, Serializer +from prefect.serializers import Serializer from prefect.settings.context import get_current_settings from prefect.types import DateTime from prefect.utilities.annotations import NotSet @@ -76,7 +69,6 @@ def DEFAULT_STORAGE_KEY_FN() -> str: logger: "logging.Logger" = get_logger("results") P = ParamSpec("P") -R = TypeVar("R") _default_storages: dict[tuple[str, str], WritableFileSystem] = {} @@ -372,6 +364,31 @@ def result_storage_block_id(self) -> UUID | None: return None return getattr(self.result_storage, "_block_document_id", None) + @classmethod + async def _from_metadata(cls, metadata: ResultRecordMetadata) -> "ResultRecord[R]": + """ + Create a result record from metadata. + + Will use the result record metadata to fetch data via a result store. + + Args: + metadata: The metadata to create the result record from. + + Returns: + ResultRecord: The result record. + """ + if metadata.storage_block_id is None: + storage_block = None + else: + storage_block = await aresolve_result_storage(metadata.storage_block_id) + store = cls(result_storage=storage_block, serializer=metadata.serializer) + if metadata.storage_key is None: + raise ValueError( + "storage_key is required to hydrate a result record from metadata" + ) + result = await store.aread(metadata.storage_key) + return result + @sync_compatible async def update_for_flow(self, flow: "Flow[..., Any]") -> Self: """ @@ -554,6 +571,8 @@ async def _read(self, key: str, holder: str) -> "ResultRecord[Any]": A result record. """ + from prefect._experimental.lineage import emit_result_read_event + if self.lock_manager is not None and not self.is_lock_holder(key, holder): await self.await_for_lock(key) @@ -743,6 +762,8 @@ async def _persist_result_record( "Storage key is required on result record" ) + from prefect._experimental.lineage import emit_result_write_event + key = result_record.metadata.storage_key if result_record.metadata.storage_block_id is None: basepath = ( @@ -1007,230 +1028,3 @@ def get_result_store() -> ResultStore: else: result_store = run_context.result_store return result_store - - -class ResultRecordMetadata(BaseModel): - """ - Metadata for a result record. - """ - - storage_key: Optional[str] = Field( - default=None - ) # optional for backwards compatibility - expiration: Optional[DateTime] = Field(default=None) - serializer: Serializer = Field(default_factory=PickleSerializer) - prefect_version: str = Field(default=prefect.__version__) - storage_block_id: Optional[uuid.UUID] = Field(default=None) - - def dump_bytes(self) -> bytes: - """ - Serialize the metadata to bytes. - - Returns: - bytes: the serialized metadata - """ - return self.model_dump_json(serialize_as_any=True).encode() - - @classmethod - def load_bytes(cls, data: bytes) -> "ResultRecordMetadata": - """ - Deserialize metadata from bytes. - - Args: - data: the serialized metadata - - Returns: - ResultRecordMetadata: the deserialized metadata - """ - return cls.model_validate_json(data) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, ResultRecordMetadata): - return False - return ( - self.storage_key == other.storage_key - and self.expiration == other.expiration - and self.serializer == other.serializer - and self.prefect_version == other.prefect_version - and self.storage_block_id == other.storage_block_id - ) - - -class ResultRecord(BaseModel, Generic[R]): - """ - A record of a result. - """ - - metadata: ResultRecordMetadata - result: R - - @property - def expiration(self) -> DateTime | None: - return self.metadata.expiration - - @property - def serializer(self) -> Serializer: - return self.metadata.serializer - - def serialize_result(self) -> bytes: - try: - data = self.serializer.dumps(self.result) - except Exception as exc: - extra_info = ( - 'You can try a different serializer (e.g. result_serializer="json") ' - "or disabling persistence (persist_result=False) for this flow or task." - ) - # check if this is a known issue with cloudpickle and pydantic - # and add extra information to help the user recover - - if ( - isinstance(exc, TypeError) - and isinstance(self.result, BaseModel) - and str(exc).startswith("cannot pickle") - ): - try: - from IPython.core.getipython import get_ipython - - if get_ipython() is not None: - extra_info = inspect.cleandoc( - """ - This is a known issue in Pydantic that prevents - locally-defined (non-imported) models from being - serialized by cloudpickle in IPython/Jupyter - environments. Please see - https://github.com/pydantic/pydantic/issues/8232 for - more information. To fix the issue, either: (1) move - your Pydantic class definition to an importable - location, (2) use the JSON serializer for your flow - or task (`result_serializer="json"`), or (3) - disable result persistence for your flow or task - (`persist_result=False`). - """ - ).replace("\n", " ") - except ImportError: - pass - raise SerializationError( - f"Failed to serialize object of type {type(self.result).__name__!r} with " - f"serializer {self.serializer.type!r}. {extra_info}" - ) from exc - - return data - - @model_validator(mode="before") - @classmethod - def coerce_old_format(cls, value: dict[str, Any] | Any) -> dict[str, Any]: - if isinstance(value, dict): - if "data" in value: - value["result"] = value.pop("data") - if "metadata" not in value: - value["metadata"] = {} - if "expiration" in value: - value["metadata"]["expiration"] = value.pop("expiration") - if "serializer" in value: - value["metadata"]["serializer"] = value.pop("serializer") - if "prefect_version" in value: - value["metadata"]["prefect_version"] = value.pop("prefect_version") - return value - - @classmethod - async def _from_metadata(cls, metadata: ResultRecordMetadata) -> "ResultRecord[R]": - """ - Create a result record from metadata. - - Will use the result record metadata to fetch data via a result store. - - Args: - metadata: The metadata to create the result record from. - - Returns: - ResultRecord: The result record. - """ - if metadata.storage_block_id is None: - storage_block = None - else: - storage_block = await aresolve_result_storage(metadata.storage_block_id) - store = ResultStore( - result_storage=storage_block, serializer=metadata.serializer - ) - if metadata.storage_key is None: - raise ValueError( - "storage_key is required to hydrate a result record from metadata" - ) - result = await store.aread(metadata.storage_key) - return result - - def serialize_metadata(self) -> bytes: - return self.metadata.dump_bytes() - - def serialize( - self, - ) -> bytes: - """ - Serialize the record to bytes. - - Returns: - bytes: the serialized record - - """ - return ( - self.model_copy(update={"result": self.serialize_result()}) - .model_dump_json(serialize_as_any=True) - .encode() - ) - - @classmethod - def deserialize( - cls, data: bytes, backup_serializer: Serializer | None = None - ) -> "ResultRecord[R]": - """ - Deserialize a record from bytes. - - Args: - data: the serialized record - backup_serializer: The serializer to use to deserialize the result record. Only - necessary if the provided data does not specify a serializer. - - Returns: - ResultRecord: the deserialized record - """ - try: - instance = cls.model_validate_json(data) - except ValidationError: - if backup_serializer is None: - raise - else: - result = backup_serializer.loads(data) - return cls( - metadata=ResultRecordMetadata(serializer=backup_serializer), - result=result, - ) - if isinstance(instance.result, bytes): - instance.result = instance.serializer.loads(instance.result) - elif isinstance(instance.result, str): - instance.result = instance.serializer.loads(instance.result.encode()) - return instance - - @classmethod - def deserialize_from_result_and_metadata( - cls, result: bytes, metadata: bytes - ) -> "ResultRecord[R]": - """ - Deserialize a record from separate result and metadata bytes. - - Args: - result: the result - metadata: the serialized metadata - - Returns: - ResultRecord: the deserialized record - """ - result_record_metadata = ResultRecordMetadata.load_bytes(metadata) - return cls( - metadata=result_record_metadata, - result=result_record_metadata.serializer.loads(result), - ) - - def __eq__(self, other: Any | "ResultRecord[Any]") -> bool: - if not isinstance(other, ResultRecord): - return False - return self.metadata == other.metadata and self.result == other.result diff --git a/src/prefect/states.py b/src/prefect/states.py index 9df16927bd1da..8a13da4b88fc6 100644 --- a/src/prefect/states.py +++ b/src/prefect/states.py @@ -97,10 +97,10 @@ async def _get_state_result_data_with_retries( # grace here about missing results. The exception below could come in the form # of a missing file, a short read, or other types of errors depending on the # result storage backend. - from prefect.results import ( - ResultRecord, + from prefect._result_records import ( ResultRecordMetadata, ) + from prefect.results import ResultStore if retry_result_failure is False: max_attempts = 1 @@ -110,7 +110,7 @@ async def _get_state_result_data_with_retries( for i in range(1, max_attempts + 1): try: if isinstance(state.data, ResultRecordMetadata): - record = await ResultRecord._from_metadata(state.data) + record = await ResultStore._from_metadata(state.data) return record.result else: return await state.data.get() @@ -462,10 +462,11 @@ async def get_state_exception(state: State) -> BaseException: - `CrashedRun` if the state type is CRASHED. - `CancelledRun` if the state type is CANCELLED. """ - from prefect.results import ( + from prefect._result_records import ( ResultRecord, ResultRecordMetadata, ) + from prefect.results import ResultStore if state.is_failed(): wrapper = FailedRun @@ -482,7 +483,7 @@ async def get_state_exception(state: State) -> BaseException: if isinstance(state.data, ResultRecord): result = state.data.result elif isinstance(state.data, ResultRecordMetadata): - record = await ResultRecord._from_metadata(state.data) + record = await ResultStore._from_metadata(state.data) result = record.result elif state.data is None: result = None diff --git a/tests/results/test_result_record.py b/tests/results/test_result_record.py index 7dcd2098705da..b1037294dc4bd 100644 --- a/tests/results/test_result_record.py +++ b/tests/results/test_result_record.py @@ -1,8 +1,9 @@ import pytest from pydantic import ValidationError +from prefect._result_records import ResultRecord, ResultRecordMetadata from prefect.filesystems import NullFileSystem -from prefect.results import ResultRecord, ResultRecordMetadata, ResultStore +from prefect.results import ResultStore from prefect.serializers import JSONSerializer from prefect.settings import PREFECT_LOCAL_STORAGE_PATH @@ -37,7 +38,7 @@ async def test_from_metadata(self): result_record = store.create_result_record("The results are in...", "the-key") await store.apersist_result_record(result_record) - loaded = await ResultRecord._from_metadata(result_record.metadata) + loaded = await ResultStore._from_metadata(result_record.metadata) assert loaded.result == "The results are in..." async def test_from_metadata_with_raw_result(self): @@ -48,7 +49,7 @@ async def test_from_metadata_with_raw_result(self): result_record = store.create_result_record("The results are in...", "the-key") await store.apersist_result_record(result_record) - loaded = await ResultRecord._from_metadata(result_record.metadata) + loaded = await ResultStore._from_metadata(result_record.metadata) assert loaded.result == "The results are in..." # assert that the raw result was persisted without metadata diff --git a/tests/results/test_result_store.py b/tests/results/test_result_store.py index 78eb38d0f1e94..f7d28c5732fe0 100644 --- a/tests/results/test_result_store.py +++ b/tests/results/test_result_store.py @@ -892,7 +892,9 @@ async def test_result_store_emits_write_event( filesystem = LocalFileSystem(basepath=tmp_path) result_store = ResultStore(result_storage=filesystem) - with mock.patch("prefect.results.emit_result_write_event") as mock_emit: + with mock.patch( + "prefect._experimental.lineage.emit_result_write_event" + ) as mock_emit: await result_store.awrite(key="test", obj="test") resolved_key_path = result_store._resolved_key_path("test") mock_emit.assert_called_once_with(result_store, resolved_key_path) @@ -906,7 +908,9 @@ async def test_result_store_emits_read_event(self, tmp_path, enable_lineage_even # without the store's in-memory cache. other_result_store = ResultStore(result_storage=filesystem) - with mock.patch("prefect.results.emit_result_read_event") as mock_emit: + with mock.patch( + "prefect._experimental.lineage.emit_result_read_event" + ) as mock_emit: await other_result_store.aread(key="test") resolved_key_path = other_result_store._resolved_key_path("test") mock_emit.assert_called_once_with(other_result_store, resolved_key_path) @@ -919,7 +923,9 @@ async def test_result_store_emits_cached_read_event( ) await result_store.awrite(key="test", obj="test") - with mock.patch("prefect.results.emit_result_read_event") as mock_emit: + with mock.patch( + "prefect._experimental.lineage.emit_result_read_event" + ) as mock_emit: await result_store.aread(key="test") # cached read resolved_key_path = result_store._resolved_key_path("test") mock_emit.assert_called_once_with( diff --git a/tests/server/orchestration/test_core_policy.py b/tests/server/orchestration/test_core_policy.py index c53cab46d5c73..ad954135e2464 100644 --- a/tests/server/orchestration/test_core_policy.py +++ b/tests/server/orchestration/test_core_policy.py @@ -10,7 +10,7 @@ import pytest import sqlalchemy as sa -from prefect.results import ResultRecordMetadata +from prefect._result_records import ResultRecordMetadata from prefect.server import schemas from prefect.server.database import orm_models as orm from prefect.server.exceptions import ObjectNotFoundError