diff --git a/aws_lambda_powertools/utilities/idempotency/base.py b/aws_lambda_powertools/utilities/idempotency/base.py index 46aa5ef8962..a8d509b86eb 100644 --- a/aws_lambda_powertools/utilities/idempotency/base.py +++ b/aws_lambda_powertools/utilities/idempotency/base.py @@ -18,6 +18,12 @@ BasePersistenceLayer, DataRecord, ) +from aws_lambda_powertools.utilities.idempotency.serialization.base import ( + BaseIdempotencySerializer, +) +from aws_lambda_powertools.utilities.idempotency.serialization.no_op import ( + NoOpSerializer, +) MAX_RETRIES = 2 logger = logging.getLogger(__name__) @@ -51,6 +57,7 @@ def __init__( function_payload: Any, config: IdempotencyConfig, persistence_store: BasePersistenceLayer, + output_serializer: Optional[BaseIdempotencySerializer] = None, function_args: Optional[Tuple] = None, function_kwargs: Optional[Dict] = None, ): @@ -65,12 +72,16 @@ def __init__( Idempotency Configuration persistence_store : BasePersistenceLayer Instance of persistence layer to store idempotency records + output_serializer: Optional[BaseIdempotencySerializer] + Serializer to transform the data to and from a dictionary. 
+ If not supplied, no serialization is done via the NoOpSerializer function_args: Optional[Tuple] Function arguments function_kwargs: Optional[Dict] Function keyword arguments """ self.function = function + self.output_serializer = output_serializer or NoOpSerializer() self.data = deepcopy(_prepare_data(function_payload)) self.fn_args = function_args self.fn_kwargs = function_kwargs @@ -170,7 +181,7 @@ def _get_idempotency_record(self) -> Optional[DataRecord]: return data_record - def _handle_for_status(self, data_record: DataRecord) -> Optional[Dict[Any, Any]]: + def _handle_for_status(self, data_record: DataRecord) -> Optional[Any]: """ Take appropriate action based on data_record's status @@ -180,8 +191,9 @@ def _handle_for_status(self, data_record: DataRecord) -> Optional[Dict[Any, Any] Returns ------- - Optional[Dict[Any, Any] + Optional[Any] Function's response previously used for this idempotency key, if it has successfully executed already. + In case an output serializer is configured, the response is deserialized. 
Raises ------ @@ -206,8 +218,10 @@ def _handle_for_status(self, data_record: DataRecord) -> Optional[Dict[Any, Any] f"Execution already in progress with idempotency key: " f"{self.persistence_store.event_key_jmespath}={data_record.idempotency_key}", ) - - return data_record.response_json_as_dict() + response_dict: Optional[dict] = data_record.response_json_as_dict() + if response_dict is not None: + return self.output_serializer.from_dict(response_dict) + return None def _get_function_response(self): try: @@ -226,7 +240,9 @@ def _get_function_response(self): else: try: - self.persistence_store.save_success(data=self.data, result=response) + # Compare against None (not truthiness) so falsy results such as {}, [] or 0 are still stored and replayed + serialized_response: dict = self.output_serializer.to_dict(response) if response is not None else None + self.persistence_store.save_success(data=self.data, result=serialized_response) except Exception as save_exception: raise IdempotencyPersistenceLayerError( "Failed to update record state to success in idempotency store", diff --git a/aws_lambda_powertools/utilities/idempotency/exceptions.py b/aws_lambda_powertools/utilities/idempotency/exceptions.py index 67a8d6721b1..6e5930549c4 100644 --- a/aws_lambda_powertools/utilities/idempotency/exceptions.py +++ b/aws_lambda_powertools/utilities/idempotency/exceptions.py @@ -71,3 +71,15 @@ class IdempotencyKeyError(BaseError): """ Payload does not contain an idempotent key """ + + +class IdempotencyModelTypeError(BaseError): + """ + Model type does not match expected payload output + """ + + +class IdempotencyNoSerializationModelError(BaseError): + """ + No model was supplied to the serializer + """ diff --git a/aws_lambda_powertools/utilities/idempotency/idempotency.py b/aws_lambda_powertools/utilities/idempotency/idempotency.py index 76d353d205e..f38a860a6c7 100644 --- a/aws_lambda_powertools/utilities/idempotency/idempotency.py +++ b/aws_lambda_powertools/utilities/idempotency/idempotency.py @@ -4,7 +4,8 @@ import functools import logging import os -from typing import Any, Callable, Dict, 
Optional, cast +from inspect import isclass +from typing import Any, Callable, Dict, Optional, Type, Union, cast from aws_lambda_powertools.middleware_factory import lambda_handler_decorator from aws_lambda_powertools.shared import constants @@ -14,6 +15,10 @@ from aws_lambda_powertools.utilities.idempotency.persistence.base import ( BasePersistenceLayer, ) +from aws_lambda_powertools.utilities.idempotency.serialization.base import ( + BaseIdempotencyModelSerializer, + BaseIdempotencySerializer, +) from aws_lambda_powertools.utilities.typing import LambdaContext logger = logging.getLogger(__name__) @@ -85,6 +90,7 @@ def idempotent_function( data_keyword_argument: str, persistence_store: BasePersistenceLayer, config: Optional[IdempotencyConfig] = None, + output_serializer: Optional[Union[BaseIdempotencySerializer, Type[BaseIdempotencyModelSerializer]]] = None, ) -> Any: """ Decorator to handle idempotency of any function @@ -99,6 +105,11 @@ def idempotent_function( Instance of BasePersistenceLayer to store data config: IdempotencyConfig Configuration + output_serializer: Optional[Union[BaseIdempotencySerializer, Type[BaseIdempotencyModelSerializer]]] + Serializer to transform the data to and from a dictionary. + If not supplied, no serialization is done via the NoOpSerializer. + In case a serializer of type inheriting BaseIdempotencyModelSerializer is given, + the serializer is derived from the function return type. 
Examples -------- @@ -124,9 +135,14 @@ def process_order(customer_id: str, order: dict, **kwargs): data_keyword_argument=data_keyword_argument, persistence_store=persistence_store, config=config, + output_serializer=output_serializer, ), ) + if isclass(output_serializer) and issubclass(output_serializer, BaseIdempotencyModelSerializer): + # instantiate an instance of the serializer class + output_serializer = output_serializer.instantiate(function.__annotations__.get("return", None)) + config = config or IdempotencyConfig() @functools.wraps(function) @@ -147,6 +163,7 @@ def decorate(*args, **kwargs): function_payload=payload, config=config, persistence_store=persistence_store, + output_serializer=output_serializer, function_args=args, function_kwargs=kwargs, ) diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/__init__.py b/aws_lambda_powertools/utilities/idempotency/serialization/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/base.py b/aws_lambda_powertools/utilities/idempotency/serialization/base.py new file mode 100644 index 00000000000..45317bd0315 --- /dev/null +++ b/aws_lambda_powertools/utilities/idempotency/serialization/base.py @@ -0,0 +1,47 @@ +""" +Serialization for supporting idempotency +""" +from abc import ABC, abstractmethod +from typing import Any, Dict + + +class BaseIdempotencySerializer(ABC): + """ + Abstract Base Class for Idempotency serialization layer, supporting dict operations. + """ + + @abstractmethod + def to_dict(self, data: Any) -> Dict: + raise NotImplementedError("Implementation of to_dict is required") + + @abstractmethod + def from_dict(self, data: Dict) -> Any: + raise NotImplementedError("Implementation of from_dict is required") + + +class BaseIdempotencyModelSerializer(BaseIdempotencySerializer): + """ + Abstract Base Class for Idempotency serialization layer, for using a model as data object representation. 
+ """ + + @classmethod + @abstractmethod + def instantiate(cls, model_type: Any) -> BaseIdempotencySerializer: + """ + Creates an instance of a serializer based on a provided model type. + In case the model_type is unknown, None will be sent as `model_type`. + It's on the implementer to verify that: + - None is handled correctly + - A model type not matching the expected types is handled + + Parameters + ---------- + model_type: Any + The model type to instantiate the class for + + Returns + ------- + BaseIdempotencySerializer + Instance of the serializer class + """ + pass diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/custom_dict.py b/aws_lambda_powertools/utilities/idempotency/serialization/custom_dict.py new file mode 100644 index 00000000000..2af8bed08b0 --- /dev/null +++ b/aws_lambda_powertools/utilities/idempotency/serialization/custom_dict.py @@ -0,0 +1,23 @@ +from typing import Any, Callable, Dict + +from aws_lambda_powertools.utilities.idempotency.serialization.base import BaseIdempotencySerializer + + +class CustomDictSerializer(BaseIdempotencySerializer): + def __init__(self, to_dict: Callable[[Any], Dict], from_dict: Callable[[Dict], Any]): + """ + Parameters + ---------- + to_dict: Callable[[Any], Dict] + A function capable of transforming the saved data object representation into a dictionary + from_dict: Callable[[Dict], Any] + A function capable of transforming the saved dictionary into the original data object representation + """ + self.__to_dict: Callable[[Any], Dict] = to_dict + self.__from_dict: Callable[[Dict], Any] = from_dict + + def to_dict(self, data: Any) -> Dict: + return self.__to_dict(data) + + def from_dict(self, data: Dict) -> Any: + return self.__from_dict(data) diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/dataclass.py b/aws_lambda_powertools/utilities/idempotency/serialization/dataclass.py new file mode 100644 index 00000000000..dac77ed7345 --- /dev/null +++ 
b/aws_lambda_powertools/utilities/idempotency/serialization/dataclass.py @@ -0,0 +1,43 @@ +from dataclasses import asdict, is_dataclass +from typing import Any, Dict, Type + +from aws_lambda_powertools.utilities.idempotency.exceptions import ( + IdempotencyModelTypeError, + IdempotencyNoSerializationModelError, +) +from aws_lambda_powertools.utilities.idempotency.serialization.base import ( + BaseIdempotencyModelSerializer, + BaseIdempotencySerializer, +) + +DataClass = Any + + +class DataclassSerializer(BaseIdempotencyModelSerializer): + """ + A serializer class for transforming data between dataclass objects and dictionaries. + """ + + def __init__(self, model: Type[DataClass]): + """ + Parameters + ---------- + model: Type[DataClass] + A dataclass type to be used for serialization and deserialization + """ + self.__model: Type[DataClass] = model + + def to_dict(self, data: DataClass) -> Dict: + return asdict(data) + + def from_dict(self, data: Dict) -> DataClass: + return self.__model(**data) + + @classmethod + def instantiate(cls, model_type: Any) -> BaseIdempotencySerializer: + if model_type is None: + raise IdempotencyNoSerializationModelError("No serialization model was supplied") + + if not is_dataclass(model_type): + raise IdempotencyModelTypeError("Model type is not inherited of dataclass type") + return cls(model=model_type) diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/no_op.py b/aws_lambda_powertools/utilities/idempotency/serialization/no_op.py new file mode 100644 index 00000000000..59185f704e7 --- /dev/null +++ b/aws_lambda_powertools/utilities/idempotency/serialization/no_op.py @@ -0,0 +1,18 @@ +from typing import Dict + +from aws_lambda_powertools.utilities.idempotency.serialization.base import BaseIdempotencySerializer + + +class NoOpSerializer(BaseIdempotencySerializer): + def __init__(self): + """ + Parameters + ---------- + Default serializer, does not transform data + """ + + def to_dict(self, data: Dict) -> Dict: + 
return data + + def from_dict(self, data: Dict) -> Dict: + return data diff --git a/aws_lambda_powertools/utilities/idempotency/serialization/pydantic.py b/aws_lambda_powertools/utilities/idempotency/serialization/pydantic.py new file mode 100644 index 00000000000..0c168233bff --- /dev/null +++ b/aws_lambda_powertools/utilities/idempotency/serialization/pydantic.py @@ -0,0 +1,48 @@ +from typing import Any, Dict, Type + +from pydantic import BaseModel + +from aws_lambda_powertools.utilities.idempotency.exceptions import ( + IdempotencyModelTypeError, + IdempotencyNoSerializationModelError, +) +from aws_lambda_powertools.utilities.idempotency.serialization.base import ( + BaseIdempotencyModelSerializer, + BaseIdempotencySerializer, +) + + +class PydanticSerializer(BaseIdempotencyModelSerializer): + """Pydantic serializer for idempotency models""" + + def __init__(self, model: Type[BaseModel]): + """ + Parameters + ---------- + model: Model + Pydantic model to be used for serialization + """ + self.__model: Type[BaseModel] = model + + def to_dict(self, data: BaseModel) -> Dict: + if callable(getattr(data, "model_dump", None)): + # Support for pydantic V2 + return data.model_dump() # type: ignore[unused-ignore,attr-defined] + return data.dict() + + def from_dict(self, data: Dict) -> BaseModel: + if callable(getattr(self.__model, "model_validate", None)): + # Support for pydantic V2 + return self.__model.model_validate(data) # type: ignore[unused-ignore,attr-defined] + return self.__model.parse_obj(data) + + @classmethod + def instantiate(cls, model_type: Any) -> BaseIdempotencySerializer: + if model_type is None: + raise IdempotencyNoSerializationModelError("No serialization model was supplied") + + # issubclass() raises TypeError for non-class annotations (e.g. Optional[Model]); check for a class first + if not (isinstance(model_type, type) and issubclass(model_type, BaseModel)): + raise IdempotencyModelTypeError("Model type is not inherited from pydantic BaseModel") + + return cls(model=model_type) diff --git a/docs/utilities/idempotency.md b/docs/utilities/idempotency.md index 6e5c47af6fc..3f55b34c25c 
100644 --- a/docs/utilities/idempotency.md +++ b/docs/utilities/idempotency.md @@ -152,6 +152,45 @@ When using `idempotent_function`, you must tell us which keyword parameter in yo --8<-- "examples/idempotency/src/working_with_idempotent_function_pydantic.py" ``` +#### Output serialization + +The default return of the `idempotent_function` decorator is a JSON object, but you can customize the function's return type by utilizing the `output_serializer` parameter. The output serializer supports any JSON serializable data, **Python Dataclasses** and **Pydantic Models**. + +!!! info "When using the `output_serializer` parameter, the data will continue to be stored in DynamoDB as a JSON object." + +Working with Pydantic Models: + +=== "Explicitly passing the Pydantic model type" + + ```python hl_lines="6 24 25 32 35 44" + --8<-- "examples/idempotency/src/working_with_pydantic_explicitly_output_serializer.py" + ``` +=== "Deducing the Pydantic model type from the return type annotation" + + ```python hl_lines="6 24 25 32 36 45" + --8<-- "examples/idempotency/src/working_with_pydantic_deduced_output_serializer.py" + ``` + +Working with Python Dataclasses: + +=== "Explicitly passing the model type" + + ```python hl_lines="8 27-29 36 39 48" + --8<-- "examples/idempotency/src/working_with_dataclass_explicitly_output_serializer.py" + ``` + +=== "Deducing the model type from the return type annotation" + + ```python hl_lines="8 27-29 36 40 49" + --8<-- "examples/idempotency/src/working_with_dataclass_deduced_output_serializer.py" + ``` + +=== "Using A Custom Type (Dataclasses)" + + ```python hl_lines="9 33 37 41-44 51 54" + --8<-- "examples/idempotency/src/working_with_idempotent_function_custom_output_serializer.py" + ``` + #### Batch integration You can can easily integrate with [Batch utility](batch.md){target="_blank"} via context manager. 
This ensures that you process each record in an idempotent manner, and guard against a [Lambda timeout](#lambda-timeouts) idempotent situation. diff --git a/examples/idempotency/src/working_with_dataclass_deduced_output_serializer.py b/examples/idempotency/src/working_with_dataclass_deduced_output_serializer.py new file mode 100644 index 00000000000..3feb5153e34 --- /dev/null +++ b/examples/idempotency/src/working_with_dataclass_deduced_output_serializer.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass + +from aws_lambda_powertools.utilities.idempotency import ( + DynamoDBPersistenceLayer, + IdempotencyConfig, + idempotent_function, +) +from aws_lambda_powertools.utilities.idempotency.serialization.dataclass import DataclassSerializer +from aws_lambda_powertools.utilities.typing import LambdaContext + +dynamodb = DynamoDBPersistenceLayer(table_name="IdempotencyTable") +config = IdempotencyConfig(event_key_jmespath="order_id") # see Choosing a payload subset section + + +@dataclass +class OrderItem: + sku: str + description: str + + +@dataclass +class Order: + item: OrderItem + order_id: int + + +@dataclass +class OrderOutput: + order_id: int + + +@idempotent_function( + data_keyword_argument="order", + config=config, + persistence_store=dynamodb, + output_serializer=DataclassSerializer, +) +# order output is deduced from return type +def deduced_order_output_serializer(order: Order) -> OrderOutput: + return OrderOutput(order_id=order.order_id) + + +def lambda_handler(event: dict, context: LambdaContext): + config.register_lambda_context(context) # see Lambda timeouts section + order_item = OrderItem(sku="fake", description="sample") + order = Order(item=order_item, order_id=1) + + # `order` parameter must be called as a keyword argument to work + deduced_order_output_serializer(order=order) diff --git a/examples/idempotency/src/working_with_dataclass_explicitly_output_serializer.py 
b/examples/idempotency/src/working_with_dataclass_explicitly_output_serializer.py new file mode 100644 index 00000000000..95b65c570e7 --- /dev/null +++ b/examples/idempotency/src/working_with_dataclass_explicitly_output_serializer.py @@ -0,0 +1,48 @@ +from dataclasses import dataclass + +from aws_lambda_powertools.utilities.idempotency import ( + DynamoDBPersistenceLayer, + IdempotencyConfig, + idempotent_function, +) +from aws_lambda_powertools.utilities.idempotency.serialization.dataclass import DataclassSerializer +from aws_lambda_powertools.utilities.typing import LambdaContext + +dynamodb = DynamoDBPersistenceLayer(table_name="IdempotencyTable") +config = IdempotencyConfig(event_key_jmespath="order_id") # see Choosing a payload subset section + + +@dataclass +class OrderItem: + sku: str + description: str + + +@dataclass +class Order: + item: OrderItem + order_id: int + + +@dataclass +class OrderOutput: + order_id: int + + +@idempotent_function( + data_keyword_argument="order", + config=config, + persistence_store=dynamodb, + output_serializer=DataclassSerializer(model=OrderOutput), +) +def explicit_order_output_serializer(order: Order): + return OrderOutput(order_id=order.order_id) + + +def lambda_handler(event: dict, context: LambdaContext): + config.register_lambda_context(context) # see Lambda timeouts section + order_item = OrderItem(sku="fake", description="sample") + order = Order(item=order_item, order_id=1) + + # `order` parameter must be called as a keyword argument to work + explicit_order_output_serializer(order=order) diff --git a/examples/idempotency/src/working_with_idempotent_function_custom_output_serializer.py b/examples/idempotency/src/working_with_idempotent_function_custom_output_serializer.py new file mode 100644 index 00000000000..f8ef30c7ab2 --- /dev/null +++ b/examples/idempotency/src/working_with_idempotent_function_custom_output_serializer.py @@ -0,0 +1,63 @@ +from dataclasses import asdict, dataclass +from typing import Any, Dict + 
+from aws_lambda_powertools.utilities.idempotency import ( + DynamoDBPersistenceLayer, + IdempotencyConfig, + idempotent_function, +) +from aws_lambda_powertools.utilities.idempotency.serialization.custom_dict import CustomDictSerializer +from aws_lambda_powertools.utilities.typing import LambdaContext + +dynamodb = DynamoDBPersistenceLayer(table_name="IdempotencyTable") +config = IdempotencyConfig(event_key_jmespath="order_id") # see Choosing a payload subset section + + +@dataclass +class OrderItem: + sku: str + description: str + + +@dataclass +class Order: + item: OrderItem + order_id: int + + +@dataclass +class OrderOutput: + order_id: int + + +def custom_to_dict(x: Any) -> Dict: + return asdict(x) + + +def custom_from_dict(x: Dict) -> Any: + return OrderOutput(**x) + + +order_output_serializer = CustomDictSerializer( + to_dict=custom_to_dict, + from_dict=custom_from_dict, +) + + +@idempotent_function( + data_keyword_argument="order", + config=config, + persistence_store=dynamodb, + output_serializer=order_output_serializer, +) +def process_order(order: Order) -> OrderOutput: + return OrderOutput(order_id=order.order_id) + + +def lambda_handler(event: dict, context: LambdaContext): + config.register_lambda_context(context) # see Lambda timeouts section + order_item = OrderItem(sku="fake", description="sample") + order = Order(item=order_item, order_id=1) + + # `order` parameter must be called as a keyword argument to work + process_order(order=order) diff --git a/examples/idempotency/src/working_with_pydantic_deduced_output_serializer.py b/examples/idempotency/src/working_with_pydantic_deduced_output_serializer.py new file mode 100644 index 00000000000..98b7ed52bf8 --- /dev/null +++ b/examples/idempotency/src/working_with_pydantic_deduced_output_serializer.py @@ -0,0 +1,45 @@ +from aws_lambda_powertools.utilities.idempotency import ( + DynamoDBPersistenceLayer, + IdempotencyConfig, + idempotent_function, +) +from 
aws_lambda_powertools.utilities.idempotency.serialization.pydantic import PydanticSerializer +from aws_lambda_powertools.utilities.parser import BaseModel +from aws_lambda_powertools.utilities.typing import LambdaContext + +dynamodb = DynamoDBPersistenceLayer(table_name="IdempotencyTable") +config = IdempotencyConfig(event_key_jmespath="order_id") # see Choosing a payload subset section + + +class OrderItem(BaseModel): + sku: str + description: str + + +class Order(BaseModel): + item: OrderItem + order_id: int + + +class OrderOutput(BaseModel): + order_id: int + + +@idempotent_function( + data_keyword_argument="order", + config=config, + persistence_store=dynamodb, + output_serializer=PydanticSerializer, +) +# order output is deduced from return type +def deduced_order_output_serializer(order: Order) -> OrderOutput: + return OrderOutput(order_id=order.order_id) + + +def lambda_handler(event: dict, context: LambdaContext): + config.register_lambda_context(context) # see Lambda timeouts section + order_item = OrderItem(sku="fake", description="sample") + order = Order(item=order_item, order_id=1) + + # `order` parameter must be called as a keyword argument to work + deduced_order_output_serializer(order=order) diff --git a/examples/idempotency/src/working_with_pydantic_explicitly_output_serializer.py b/examples/idempotency/src/working_with_pydantic_explicitly_output_serializer.py new file mode 100644 index 00000000000..6219e688e17 --- /dev/null +++ b/examples/idempotency/src/working_with_pydantic_explicitly_output_serializer.py @@ -0,0 +1,44 @@ +from aws_lambda_powertools.utilities.idempotency import ( + DynamoDBPersistenceLayer, + IdempotencyConfig, + idempotent_function, +) +from aws_lambda_powertools.utilities.idempotency.serialization.pydantic import PydanticSerializer +from aws_lambda_powertools.utilities.parser import BaseModel +from aws_lambda_powertools.utilities.typing import LambdaContext + +dynamodb = DynamoDBPersistenceLayer(table_name="IdempotencyTable") 
+config = IdempotencyConfig(event_key_jmespath="order_id") # see Choosing a payload subset section + + +class OrderItem(BaseModel): + sku: str + description: str + + +class Order(BaseModel): + item: OrderItem + order_id: int + + +class OrderOutput(BaseModel): + order_id: int + + +@idempotent_function( + data_keyword_argument="order", + config=config, + persistence_store=dynamodb, + output_serializer=PydanticSerializer(model=OrderOutput), +) +def explicit_order_output_serializer(order: Order): + return OrderOutput(order_id=order.order_id) + + +def lambda_handler(event: dict, context: LambdaContext): + config.register_lambda_context(context) # see Lambda timeouts section + order_item = OrderItem(sku="fake", description="sample") + order = Order(item=order_item, order_id=1) + + # `order` parameter must be called as a keyword argument to work + explicit_order_output_serializer(order=order) diff --git a/tests/functional/idempotency/test_idempotency.py b/tests/functional/idempotency/test_idempotency.py index 47b24744665..24fcd76b4d5 100644 --- a/tests/functional/idempotency/test_idempotency.py +++ b/tests/functional/idempotency/test_idempotency.py @@ -17,6 +17,8 @@ from aws_lambda_powertools.utilities.idempotency import ( DynamoDBPersistenceLayer, IdempotencyConfig, + idempotent, + idempotent_function, ) from aws_lambda_powertools.utilities.idempotency.base import ( MAX_RETRIES, @@ -28,17 +30,18 @@ IdempotencyInconsistentStateError, IdempotencyInvalidStatusError, IdempotencyKeyError, + IdempotencyModelTypeError, + IdempotencyNoSerializationModelError, IdempotencyPersistenceLayerError, IdempotencyValidationError, ) -from aws_lambda_powertools.utilities.idempotency.idempotency import ( - idempotent, - idempotent_function, -) from aws_lambda_powertools.utilities.idempotency.persistence.base import ( BasePersistenceLayer, DataRecord, ) +from aws_lambda_powertools.utilities.idempotency.serialization.custom_dict import CustomDictSerializer +from 
aws_lambda_powertools.utilities.idempotency.serialization.dataclass import DataclassSerializer +from aws_lambda_powertools.utilities.idempotency.serialization.pydantic import PydanticSerializer from aws_lambda_powertools.utilities.validation import envelopes, validator from tests.functional.idempotency.utils import ( build_idempotency_put_item_stub, @@ -1196,6 +1199,297 @@ def record_handler(record): assert result == expected_result +def test_idempotent_function_serialization_custom_dict(): + # GIVEN + config = IdempotencyConfig(use_local_cache=True) + mock_event = {"customer_id": "fake", "transaction_id": "fake-id"} + idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_custom_dict..record_handler#{hash_idempotency_key(mock_event)}" # noqa E501 + persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key) + + to_dict_called = False + from_dict_called = False + + def to_dict(data): + nonlocal to_dict_called + to_dict_called = True + return data + + def from_dict(data): + nonlocal from_dict_called + from_dict_called = True + return data + + expected_result = {"message": "Foo"} + output_serializer = CustomDictSerializer( + to_dict=to_dict, + from_dict=from_dict, + ) + + @idempotent_function( + persistence_store=persistence_layer, + data_keyword_argument="record", + config=config, + output_serializer=output_serializer, + ) + def record_handler(record): + return expected_result + + record_handler(record=mock_event) + assert to_dict_called + record_handler(record=mock_event) + assert from_dict_called + + +def test_idempotent_function_serialization_no_response(): + # GIVEN + config = IdempotencyConfig(use_local_cache=True) + mock_event = {"customer_id": "fake", "transaction_id": "fake-id"} + idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_no_response..record_handler#{hash_idempotency_key(mock_event)}" # noqa E501 + persistence_layer = 
MockPersistenceLayer(expected_idempotency_key=idempotency_key) + + to_dict_called = False + from_dict_called = False + + def to_dict(data): + nonlocal to_dict_called + to_dict_called = True + return data + + def from_dict(data): + nonlocal from_dict_called + from_dict_called = True + return data + + output_serializer = CustomDictSerializer( + to_dict=to_dict, + from_dict=from_dict, + ) + + @idempotent_function( + persistence_store=persistence_layer, + data_keyword_argument="record", + config=config, + output_serializer=output_serializer, + ) + def record_handler(record): + return None + + record_handler(record=mock_event) + assert to_dict_called is False, "in case response is None, to_dict should not be called" + response = record_handler(record=mock_event) + assert response is None + assert from_dict_called is False, "in case response is None, from_dict should not be called" + + +@pytest.mark.parametrize("output_serializer_type", ["explicit", "deduced"]) +def test_idempotent_function_serialization_pydantic(output_serializer_type: str): + # GIVEN + config = IdempotencyConfig(use_local_cache=True) + mock_event = {"customer_id": "fake", "transaction_id": "fake-id"} + idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_pydantic..collect_payment#{hash_idempotency_key(mock_event)}" # noqa E501 + persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key) + + class PaymentInput(BaseModel): + customer_id: str + transaction_id: str + + class PaymentOutput(BaseModel): + customer_id: str + transaction_id: str + + if output_serializer_type == "explicit": + output_serializer = PydanticSerializer( + model=PaymentOutput, + ) + else: + output_serializer = PydanticSerializer + + @idempotent_function( + data_keyword_argument="payment", + persistence_store=persistence_layer, + config=config, + output_serializer=output_serializer, + ) + def collect_payment(payment: PaymentInput) -> PaymentOutput: + return 
PaymentOutput(**payment.dict()) + + # WHEN + payment = PaymentInput(**mock_event) + first_call: PaymentOutput = collect_payment(payment=payment) + assert first_call.customer_id == payment.customer_id + assert first_call.transaction_id == payment.transaction_id + assert isinstance(first_call, PaymentOutput) + second_call: PaymentOutput = collect_payment(payment=payment) + assert isinstance(second_call, PaymentOutput) + assert second_call.customer_id == payment.customer_id + assert second_call.transaction_id == payment.transaction_id + + +def test_idempotent_function_serialization_pydantic_failure_no_return_type(): + # GIVEN + config = IdempotencyConfig(use_local_cache=True) + mock_event = {"customer_id": "fake", "transaction_id": "fake-id"} + idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_pydantic_failure_no_return_type..collect_payment#{hash_idempotency_key(mock_event)}" # noqa E501 + persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key) + + class PaymentInput(BaseModel): + customer_id: str + transaction_id: str + + class PaymentOutput(BaseModel): + customer_id: str + transaction_id: str + + idempotent_function_decorator = idempotent_function( + data_keyword_argument="payment", + persistence_store=persistence_layer, + config=config, + output_serializer=PydanticSerializer, + ) + with pytest.raises(IdempotencyNoSerializationModelError, match="No serialization model was supplied"): + + @idempotent_function_decorator + def collect_payment(payment: PaymentInput): + return PaymentOutput(**payment.dict()) + + +def test_idempotent_function_serialization_pydantic_failure_bad_type(): + # GIVEN + config = IdempotencyConfig(use_local_cache=True) + mock_event = {"customer_id": "fake", "transaction_id": "fake-id"} + idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_pydantic_failure_no_return_type..collect_payment#{hash_idempotency_key(mock_event)}" # noqa E501 + persistence_layer = 
MockPersistenceLayer(expected_idempotency_key=idempotency_key) + + class PaymentInput(BaseModel): + customer_id: str + transaction_id: str + + class PaymentOutput(BaseModel): + customer_id: str + transaction_id: str + + idempotent_function_decorator = idempotent_function( + data_keyword_argument="payment", + persistence_store=persistence_layer, + config=config, + output_serializer=PydanticSerializer, + ) + with pytest.raises(IdempotencyModelTypeError, match="Model type is not inherited from pydantic BaseModel"): + + @idempotent_function_decorator + def collect_payment(payment: PaymentInput) -> dict: + return PaymentOutput(**payment.dict()) + + +@pytest.mark.parametrize("output_serializer_type", ["explicit", "deduced"]) +def test_idempotent_function_serialization_dataclass(output_serializer_type: str): + # GIVEN + dataclasses = get_dataclasses_lib() + config = IdempotencyConfig(use_local_cache=True) + mock_event = {"customer_id": "fake", "transaction_id": "fake-id"} + idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_dataclass..collect_payment#{hash_idempotency_key(mock_event)}" # noqa E501 + persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key) + + @dataclasses.dataclass + class PaymentInput: + customer_id: str + transaction_id: str + + @dataclasses.dataclass + class PaymentOutput: + customer_id: str + transaction_id: str + + if output_serializer_type == "explicit": + output_serializer = DataclassSerializer( + model=PaymentOutput, + ) + else: + output_serializer = DataclassSerializer + + @idempotent_function( + data_keyword_argument="payment", + persistence_store=persistence_layer, + config=config, + output_serializer=output_serializer, + ) + def collect_payment(payment: PaymentInput) -> PaymentOutput: + return PaymentOutput(**dataclasses.asdict(payment)) + + # WHEN + payment = PaymentInput(**mock_event) + first_call: PaymentOutput = collect_payment(payment=payment) + assert first_call.customer_id == 
payment.customer_id + assert first_call.transaction_id == payment.transaction_id + assert isinstance(first_call, PaymentOutput) + second_call: PaymentOutput = collect_payment(payment=payment) + assert isinstance(second_call, PaymentOutput) + assert second_call.customer_id == payment.customer_id + assert second_call.transaction_id == payment.transaction_id + + +def test_idempotent_function_serialization_dataclass_failure_no_return_type(): + # GIVEN + dataclasses = get_dataclasses_lib() + config = IdempotencyConfig(use_local_cache=True) + mock_event = {"customer_id": "fake", "transaction_id": "fake-id"} + idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_pydantic_failure_no_return_type..collect_payment#{hash_idempotency_key(mock_event)}" # noqa E501 + persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key) + + @dataclasses.dataclass + class PaymentInput: + customer_id: str + transaction_id: str + + @dataclasses.dataclass + class PaymentOutput: + customer_id: str + transaction_id: str + + idempotent_function_decorator = idempotent_function( + data_keyword_argument="payment", + persistence_store=persistence_layer, + config=config, + output_serializer=DataclassSerializer, + ) + with pytest.raises(IdempotencyNoSerializationModelError, match="No serialization model was supplied"): + + @idempotent_function_decorator + def collect_payment(payment: PaymentInput): + return PaymentOutput(**payment.dict()) + + +def test_idempotent_function_serialization_dataclass_failure_bad_type(): + # GIVEN + dataclasses = get_dataclasses_lib() + config = IdempotencyConfig(use_local_cache=True) + mock_event = {"customer_id": "fake", "transaction_id": "fake-id"} + idempotency_key = f"{TESTS_MODULE_PREFIX}.test_idempotent_function_serialization_pydantic_failure_no_return_type..collect_payment#{hash_idempotency_key(mock_event)}" # noqa E501 + persistence_layer = MockPersistenceLayer(expected_idempotency_key=idempotency_key) + + 
@dataclasses.dataclass + class PaymentInput: + customer_id: str + transaction_id: str + + @dataclasses.dataclass + class PaymentOutput: + customer_id: str + transaction_id: str + + idempotent_function_decorator = idempotent_function( + data_keyword_argument="payment", + persistence_store=persistence_layer, + config=config, + output_serializer=DataclassSerializer, + ) + with pytest.raises(IdempotencyModelTypeError, match="Model type is not inherited of dataclass type"): + + @idempotent_function_decorator + def collect_payment(payment: PaymentInput) -> dict: + return PaymentOutput(**dataclasses.asdict(payment)) + + def test_idempotent_function_arbitrary_args_kwargs(): # Scenario to validate we can use idempotent_function with a function # with an arbitrary number of args and kwargs