From 3fb448407821fe67654c1aaf63ed199a92c3c0cb Mon Sep 17 00:00:00 2001 From: Nico Tonnhofer Date: Thu, 11 Apr 2024 19:32:15 +0000 Subject: [PATCH 1/5] fix(typing): ensure return type is a T when default_value is set --- .../api_gateway_authorizer_event.py | 11 ++++ .../data_classes/appsync_resolver_event.py | 18 ++++++- .../utilities/data_classes/common.py | 18 ++++++- .../utilities/data_classes/kafka_event.py | 20 ++++++- .../utilities/data_classes/s3_object_event.py | 18 ++++++- .../data_classes/shared_functions.py | 54 +++++++++++++++---- .../utilities/data_classes/vpc_lattice.py | 6 +++ 7 files changed, 129 insertions(+), 16 deletions(-) diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py index 0b1aec43e8a..073789eabe6 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py @@ -283,6 +283,17 @@ def path_parameters(self) -> Optional[Dict[str, str]]: def stage_variables(self) -> Optional[Dict[str, str]]: return self.get("stageVariables") + @overload + def get_header_value(self, name: str, default_value: str, case_sensitive: bool = False) -> str: ... + + @overload + def get_header_value( + self, + name: str, + default_value: Optional[str] = None, + case_sensitive: Optional[bool] = False, + ) -> Optional[str]: ... + def get_header_value( self, name: str, diff --git a/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py index 14973009fb9..edb590e7770 100644 --- a/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py +++ b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, overload from aws_lambda_powertools.utilities.data_classes.common import DictWrapper from aws_lambda_powertools.utilities.data_classes.shared_functions import ( @@ -214,6 +214,22 @@ def stash(self) -> Optional[dict]: a pipeline resolver.""" return self.get("stash") + @overload + def get_header_value( + self, + name: str, + default_value: str, + case_sensitive: Optional[bool] = False, + ) -> str: ... + + @overload + def get_header_value( + self, + name: str, + default_value: Optional[str] = None, + case_sensitive: Optional[bool] = False, + ) -> Optional[str]: ... + def get_header_value( self, name: str, diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index ffca15cc318..061e87ae1ec 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -2,7 +2,7 @@ import json from collections.abc import Mapping from functools import cached_property -from typing import Any, Callable, Dict, Iterator, List, Optional, overload +from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, overload from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer from aws_lambda_powertools.utilities.data_classes.shared_functions import ( @@ -11,6 +11,8 @@ get_query_string_value, ) +T = TypeVar("T") + class DictWrapper(Mapping): """Provides a single read only access to a wrapper dict""" @@ -86,7 +88,13 @@ def _str_helper(self) -> Dict[str, Any]: def _properties(self) -> List[str]: return [p for p in dir(self.__class__) if isinstance(getattr(self.__class__, p), property)] - def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]: + @overload + def get(self, key: str, default: T) -> T: ... + + @overload + def get(self, key: str, default: Optional[T] = None) -> Optional[T]: ... + + def get(self, key: str, default: Optional[T] = None) -> Optional[T]: return self._data.get(key, default) @property @@ -172,6 +180,12 @@ def http_method(self) -> str: """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" return self["httpMethod"] + @overload + def get_query_string_value(self, name: str, default_value: str) -> str: ... + + @overload + def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ... + def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: """Get query string value by name diff --git a/aws_lambda_powertools/utilities/data_classes/kafka_event.py b/aws_lambda_powertools/utilities/data_classes/kafka_event.py index d3d1425f0f2..f30c226daf3 100644 --- a/aws_lambda_powertools/utilities/data_classes/kafka_event.py +++ b/aws_lambda_powertools/utilities/data_classes/kafka_event.py @@ -1,6 +1,6 @@ import base64 from functools import cached_property -from typing import Any, Dict, Iterator, List, Optional +from typing import Any, Dict, Iterator, List, Optional, overload from aws_lambda_powertools.utilities.data_classes.common import DictWrapper from aws_lambda_powertools.utilities.data_classes.shared_functions import ( @@ -69,10 +69,26 @@ def decoded_headers(self) -> Dict[str, bytes]: """Decodes the headers as a single dictionary.""" return {k: bytes(v) for chunk in self.headers for k, v in chunk.items()} + @overload def get_header_value( self, name: str, - default_value: Optional[Any] = None, + default_value: str, + case_sensitive: bool = False, + ) -> str: ... + + @overload + def get_header_value( + self, + name: str, + default_value: Optional[str] = None, + case_sensitive: bool = False, + ) -> Optional[str]: ... + + def get_header_value( + self, + name: str, + default_value: Optional[str] = None, case_sensitive: bool = True, ) -> Optional[str]: """Get a decoded header value by name.""" diff --git a/aws_lambda_powertools/utilities/data_classes/s3_object_event.py b/aws_lambda_powertools/utilities/data_classes/s3_object_event.py index a7953c32c59..4338a2f61e7 100644 --- a/aws_lambda_powertools/utilities/data_classes/s3_object_event.py +++ b/aws_lambda_powertools/utilities/data_classes/s3_object_event.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Dict, Optional, overload from aws_lambda_powertools.utilities.data_classes.common import DictWrapper from aws_lambda_powertools.utilities.data_classes.shared_functions import ( @@ -73,6 +73,22 @@ def headers(self) -> Dict[str, str]: The case of the original headers is retained in this map.""" return self["headers"] + @overload + def get_header_value( + self, + name: str, + default_value: str, + case_sensitive: Optional[bool] = False, + ) -> str: ... + + @overload + def get_header_value( + self, + name: str, + default_value: Optional[str] = None, + case_sensitive: Optional[bool] = False, + ) -> Optional[str]: ... + def get_header_value( self, name: str, diff --git a/aws_lambda_powertools/utilities/data_classes/shared_functions.py b/aws_lambda_powertools/utilities/data_classes/shared_functions.py index 43a3aad281b..e8c95d31760 100644 --- a/aws_lambda_powertools/utilities/data_classes/shared_functions.py +++ b/aws_lambda_powertools/utilities/data_classes/shared_functions.py @@ -1,7 +1,7 @@ from __future__ import annotations import base64 -from typing import Any, Dict +from typing import Any, Dict, List, Optional, overload def base64_decode(value: str) -> str: @@ -21,12 +21,30 @@ def base64_decode(value: str) -> str: return base64.b64decode(value).decode("UTF-8") +@overload def get_header_value( headers: dict[str, Any], name: str, - default_value: str | None, - case_sensitive: bool | None, -) -> str | None: + default_value: str, + case_sensitive: bool, +) -> str: ... + + +@overload +def get_header_value( + headers: dict[str, Any], + name: str, + default_value: Optional[str], + case_sensitive: bool, +) -> Optional[str]: ... + + +def get_header_value( + headers: dict[str, Any], + name: str, + default_value: Optional[str], + case_sensitive: bool, +) -> Optional[str]: """ Get the value of a header by its name. @@ -62,11 +80,27 @@ def get_header_value( ) +@overload +def get_query_string_value( + query_string_parameters: Dict[str, str] | None, + name: str, + default_value: str, +) -> str: ... + + +@overload +def get_query_string_value( + query_string_parameters: Dict[str, str] | None, + name: str, + default_value: Optional[str] = None, +) -> Optional[str]: ... + + def get_query_string_value( query_string_parameters: Dict[str, str] | None, name: str, - default_value: str | None = None, -) -> str | None: + default_value: Optional[str] = None, +) -> Optional[str]: """ Retrieves the value of a query string parameter specified by the given name. @@ -87,10 +121,10 @@ def get_query_string_value( def get_multi_value_query_string_values( - multi_value_query_string_parameters: Dict[str, list[str]] | None, + multi_value_query_string_parameters: Dict[str, List[str]] | None, name: str, - default_values: list[str] | None = None, -) -> list[str]: + default_values: Optional[List[str]] = None, +) -> List[str]: """ Retrieves the values of a multi-value string parameters specified by the given name. @@ -98,7 +132,7 @@ def get_multi_value_query_string_values( ---------- name: str The name of the query string parameter to retrieve. - default_value: list[str], optional + default_value: List[str], optional The default value to return if the parameter is not found. Defaults to None. Returns diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py index f997d4b3f04..e2612e0334f 100644 --- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py +++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py @@ -47,6 +47,12 @@ def http_method(self) -> str: """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" return self["method"] + @overload + def get_query_string_value(self, name: str, default_value: str) -> str: ... + + @overload + def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ... + def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: """Get query string value by name From d5f9efe9c773df4b2d075cbcc07f0a8d0c5910fa Mon Sep 17 00:00:00 2001 From: Nico Tonnhofer Date: Fri, 12 Apr 2024 10:55:19 +0000 Subject: [PATCH 2/5] fix: revert overload get in Mapping --- aws_lambda_powertools/utilities/data_classes/common.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 061e87ae1ec..b0b67ac840d 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -88,13 +88,7 @@ def _str_helper(self) -> Dict[str, Any]: def _properties(self) -> List[str]: return [p for p in dir(self.__class__) if isinstance(getattr(self.__class__, p), property)] - @overload - def get(self, key: str, default: T) -> T: ... - - @overload - def get(self, key: str, default: Optional[T] = None) -> Optional[T]: ... - - def get(self, key: str, default: Optional[T] = None) -> Optional[T]: + def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]: return self._data.get(key, default) @property From 4cbc87ed2301613eeebf42fafa0fbe9496ce4339 Mon Sep 17 00:00:00 2001 From: Nico Tonnhofer Date: Fri, 12 Apr 2024 11:04:34 +0000 Subject: [PATCH 3/5] fix: set defaults to None --- .../data_classes/shared_functions.py | 42 +++++++++---------- .../src/custom_models.py | 4 +- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/aws_lambda_powertools/utilities/data_classes/shared_functions.py b/aws_lambda_powertools/utilities/data_classes/shared_functions.py index e8c95d31760..216ff08fe2d 100644 --- a/aws_lambda_powertools/utilities/data_classes/shared_functions.py +++ b/aws_lambda_powertools/utilities/data_classes/shared_functions.py @@ -1,7 +1,7 @@ from __future__ import annotations import base64 -from typing import Any, Dict, List, Optional, overload +from typing import Any, overload def base64_decode(value: str) -> str: @@ -26,7 +26,7 @@ def get_header_value( headers: dict[str, Any], name: str, default_value: str, - case_sensitive: bool, + case_sensitive: bool | None = None, ) -> str: ... @@ -34,23 +34,23 @@ def get_header_value( def get_header_value( headers: dict[str, Any], name: str, - default_value: Optional[str], - case_sensitive: bool, -) -> Optional[str]: ... + default_value: str | None = None, + case_sensitive: bool | None = None, +) -> str | None: ... def get_header_value( headers: dict[str, Any], name: str, - default_value: Optional[str], - case_sensitive: bool, -) -> Optional[str]: + default_value: str | None = None, + case_sensitive: bool | None = None, +) -> str | None: """ Get the value of a header by its name. Parameters ---------- - headers: Dict[str, str] + headers: dict[str, str] The dictionary of headers. name: str The name of the header to retrieve. @@ -82,7 +82,7 @@ def get_header_value( @overload def get_query_string_value( - query_string_parameters: Dict[str, str] | None, + query_string_parameters: dict[str, str] | None, name: str, default_value: str, ) -> str: ... @@ -90,17 +90,17 @@ def get_query_string_value( @overload def get_query_string_value( - query_string_parameters: Dict[str, str] | None, + query_string_parameters: dict[str, str] | None, name: str, - default_value: Optional[str] = None, -) -> Optional[str]: ... + default_value: str | None = None, +) -> str | None: ... def get_query_string_value( - query_string_parameters: Dict[str, str] | None, + query_string_parameters: dict[str, str] | None, name: str, - default_value: Optional[str] = None, -) -> Optional[str]: + default_value: str | None = None, +) -> str | None: """ Retrieves the value of a query string parameter specified by the given name. @@ -121,10 +121,10 @@ def get_query_string_value( def get_multi_value_query_string_values( - multi_value_query_string_parameters: Dict[str, List[str]] | None, + multi_value_query_string_parameters: dict[str, list[str]] | None, name: str, - default_values: Optional[List[str]] = None, -) -> List[str]: + default_values: list[str] | None = None, +) -> list[str]: """ Retrieves the values of a multi-value string parameters specified by the given name. @@ -132,12 +132,12 @@ def get_multi_value_query_string_values( ---------- name: str The name of the query string parameter to retrieve. - default_value: List[str], optional + default_value: list[str], optional The default value to return if the parameter is not found. Defaults to None. Returns ------- - List[str]. optional + list[str]. optional The values of the query string parameter if found, or the default values if not found. """ diff --git a/examples/event_handler_graphql/src/custom_models.py b/examples/event_handler_graphql/src/custom_models.py index ae2f0180e15..61e03318d14 100644 --- a/examples/event_handler_graphql/src/custom_models.py +++ b/examples/event_handler_graphql/src/custom_models.py @@ -26,11 +26,11 @@ class Location(TypedDict, total=False): class MyCustomModel(AppSyncResolverEvent): @property def country_viewer(self) -> str: - return self.get_header_value(name="cloudfront-viewer-country", default_value="", case_sensitive=False) # type: ignore[return-value] # sentinel typing # noqa: E501 + return self.get_header_value(name="cloudfront-viewer-country", default_value="", case_sensitive=False) @property def api_key(self) -> str: - return self.get_header_value(name="x-api-key", default_value="", case_sensitive=False) # type: ignore[return-value] # sentinel typing # noqa: E501 + return self.get_header_value(name="x-api-key", default_value="", case_sensitive=False) @app.resolver(type_name="Query", field_name="listLocations") From 37d646800447332bae0a4db827184cba92b2ab7e Mon Sep 17 00:00:00 2001 From: Nico Tonnhofer Date: Fri, 12 Apr 2024 13:38:06 +0000 Subject: [PATCH 4/5] chore: remove unused TypeVar --- aws_lambda_powertools/utilities/data_classes/common.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index b0b67ac840d..5956a6cbdfd 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -2,7 +2,7 @@ import json from collections.abc import Mapping from functools import cached_property -from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, overload +from typing import Any, Callable, Dict, Iterator, List, Optional, overload from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer from aws_lambda_powertools.utilities.data_classes.shared_functions import ( @@ -11,8 +11,6 @@ get_query_string_value, ) -T = TypeVar("T") - class DictWrapper(Mapping): """Provides a single read only access to a wrapper dict""" From f190a134e30527a412b333154b460b2b8589b7d0 Mon Sep 17 00:00:00 2001 From: Nico Tonnhofer Date: Fri, 12 Apr 2024 15:20:51 +0000 Subject: [PATCH 5/5] chore: review findings --- .../utilities/data_classes/kafka_event.py | 4 ++-- .../data_classes/shared_functions.py | 22 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/aws_lambda_powertools/utilities/data_classes/kafka_event.py b/aws_lambda_powertools/utilities/data_classes/kafka_event.py index f30c226daf3..f20c5254730 100644 --- a/aws_lambda_powertools/utilities/data_classes/kafka_event.py +++ b/aws_lambda_powertools/utilities/data_classes/kafka_event.py @@ -74,7 +74,7 @@ def get_header_value( self, name: str, default_value: str, - case_sensitive: bool = False, + case_sensitive: bool = True, ) -> str: ... @overload @@ -82,7 +82,7 @@ def get_header_value( self, name: str, default_value: Optional[str] = None, - case_sensitive: bool = False, + case_sensitive: bool = True, ) -> Optional[str]: ... def get_header_value( diff --git a/aws_lambda_powertools/utilities/data_classes/shared_functions.py b/aws_lambda_powertools/utilities/data_classes/shared_functions.py index 216ff08fe2d..2df833e54e8 100644 --- a/aws_lambda_powertools/utilities/data_classes/shared_functions.py +++ b/aws_lambda_powertools/utilities/data_classes/shared_functions.py @@ -1,7 +1,7 @@ from __future__ import annotations import base64 -from typing import Any, overload +from typing import Any, Dict, overload def base64_decode(value: str) -> str: @@ -26,7 +26,7 @@ def get_header_value( headers: dict[str, Any], name: str, default_value: str, - case_sensitive: bool | None = None, + case_sensitive: bool | None = False, ) -> str: ... @@ -35,7 +35,7 @@ def get_header_value( headers: dict[str, Any], name: str, default_value: str | None = None, - case_sensitive: bool | None = None, + case_sensitive: bool | None = False, ) -> str | None: ... @@ -43,21 +43,21 @@ def get_header_value( headers: dict[str, Any], name: str, default_value: str | None = None, - case_sensitive: bool | None = None, + case_sensitive: bool | None = False, ) -> str | None: """ Get the value of a header by its name. Parameters ---------- - headers: dict[str, str] + headers: Dict[str, str] The dictionary of headers. name: str The name of the header to retrieve. default_value: str, optional The default value to return if the header is not found. Default is None. case_sensitive: bool, optional - Indicates whether the header name should be case-sensitive. Default is None. + Indicates whether the header name should be case-sensitive. Default is False. Returns ------- @@ -82,7 +82,7 @@ def get_header_value( @overload def get_query_string_value( - query_string_parameters: dict[str, str] | None, + query_string_parameters: Dict[str, str] | None, name: str, default_value: str, ) -> str: ... @@ -90,14 +90,14 @@ def get_query_string_value( @overload def get_query_string_value( - query_string_parameters: dict[str, str] | None, + query_string_parameters: Dict[str, str] | None, name: str, default_value: str | None = None, ) -> str | None: ... def get_query_string_value( - query_string_parameters: dict[str, str] | None, + query_string_parameters: Dict[str, str] | None, name: str, default_value: str | None = None, ) -> str | None: @@ -121,7 +121,7 @@ def get_query_string_value( def get_multi_value_query_string_values( - multi_value_query_string_parameters: dict[str, list[str]] | None, + multi_value_query_string_parameters: Dict[str, list[str]] | None, name: str, default_values: list[str] | None = None, ) -> list[str]: @@ -137,7 +137,7 @@ def get_multi_value_query_string_values( Returns ------- - list[str]. optional + List[str]. optional The values of the query string parameter if found, or the default values if not found. """