diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py index 0b1aec43e8a..073789eabe6 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py @@ -283,6 +283,17 @@ def path_parameters(self) -> Optional[Dict[str, str]]: def stage_variables(self) -> Optional[Dict[str, str]]: return self.get("stageVariables") + @overload + def get_header_value(self, name: str, default_value: str, case_sensitive: bool = False) -> str: ... + + @overload + def get_header_value( + self, + name: str, + default_value: Optional[str] = None, + case_sensitive: Optional[bool] = False, + ) -> Optional[str]: ... + def get_header_value( self, name: str, diff --git a/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py index 14973009fb9..edb590e7770 100644 --- a/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py +++ b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, overload from aws_lambda_powertools.utilities.data_classes.common import DictWrapper from aws_lambda_powertools.utilities.data_classes.shared_functions import ( @@ -214,6 +214,22 @@ def stash(self) -> Optional[dict]: a pipeline resolver.""" return self.get("stash") + @overload + def get_header_value( + self, + name: str, + default_value: str, + case_sensitive: Optional[bool] = False, + ) -> str: ... + + @overload + def get_header_value( + self, + name: str, + default_value: Optional[str] = None, + case_sensitive: Optional[bool] = False, + ) -> Optional[str]: ... + def get_header_value( self, name: str, diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index ffca15cc318..5956a6cbdfd 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -172,6 +172,12 @@ def http_method(self) -> str: """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" return self["httpMethod"] + @overload + def get_query_string_value(self, name: str, default_value: str) -> str: ... + + @overload + def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ... + def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: """Get query string value by name diff --git a/aws_lambda_powertools/utilities/data_classes/kafka_event.py b/aws_lambda_powertools/utilities/data_classes/kafka_event.py index d3d1425f0f2..f20c5254730 100644 --- a/aws_lambda_powertools/utilities/data_classes/kafka_event.py +++ b/aws_lambda_powertools/utilities/data_classes/kafka_event.py @@ -1,6 +1,6 @@ import base64 from functools import cached_property -from typing import Any, Dict, Iterator, List, Optional +from typing import Any, Dict, Iterator, List, Optional, overload from aws_lambda_powertools.utilities.data_classes.common import DictWrapper from aws_lambda_powertools.utilities.data_classes.shared_functions import ( @@ -69,10 +69,26 @@ def decoded_headers(self) -> Dict[str, bytes]: """Decodes the headers as a single dictionary.""" return {k: bytes(v) for chunk in self.headers for k, v in chunk.items()} + @overload def get_header_value( self, name: str, - default_value: Optional[Any] = None, + default_value: str, + case_sensitive: bool = True, + ) -> str: ... + + @overload + def get_header_value( + self, + name: str, + default_value: Optional[str] = None, + case_sensitive: bool = True, + ) -> Optional[str]: ... + + def get_header_value( + self, + name: str, + default_value: Optional[str] = None, case_sensitive: bool = True, ) -> Optional[str]: """Get a decoded header value by name.""" diff --git a/aws_lambda_powertools/utilities/data_classes/s3_object_event.py b/aws_lambda_powertools/utilities/data_classes/s3_object_event.py index a7953c32c59..4338a2f61e7 100644 --- a/aws_lambda_powertools/utilities/data_classes/s3_object_event.py +++ b/aws_lambda_powertools/utilities/data_classes/s3_object_event.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Dict, Optional, overload from aws_lambda_powertools.utilities.data_classes.common import DictWrapper from aws_lambda_powertools.utilities.data_classes.shared_functions import ( @@ -73,6 +73,22 @@ def headers(self) -> Dict[str, str]: The case of the original headers is retained in this map.""" return self["headers"] + @overload + def get_header_value( + self, + name: str, + default_value: str, + case_sensitive: Optional[bool] = False, + ) -> str: ... + + @overload + def get_header_value( + self, + name: str, + default_value: Optional[str] = None, + case_sensitive: Optional[bool] = False, + ) -> Optional[str]: ... + def get_header_value( self, name: str, diff --git a/aws_lambda_powertools/utilities/data_classes/shared_functions.py b/aws_lambda_powertools/utilities/data_classes/shared_functions.py index 43a3aad281b..2df833e54e8 100644 --- a/aws_lambda_powertools/utilities/data_classes/shared_functions.py +++ b/aws_lambda_powertools/utilities/data_classes/shared_functions.py @@ -1,7 +1,7 @@ from __future__ import annotations import base64 -from typing import Any, Dict +from typing import Any, Dict, overload def base64_decode(value: str) -> str: @@ -21,11 +21,29 @@ def base64_decode(value: str) -> str: return base64.b64decode(value).decode("UTF-8") +@overload def get_header_value( headers: dict[str, Any], name: str, - default_value: str | None, - case_sensitive: bool | None, + default_value: str, + case_sensitive: bool | None = False, +) -> str: ... + + +@overload +def get_header_value( + headers: dict[str, Any], + name: str, + default_value: str | None = None, + case_sensitive: bool | None = False, +) -> str | None: ... + + +def get_header_value( + headers: dict[str, Any], + name: str, + default_value: str | None = None, + case_sensitive: bool | None = False, ) -> str | None: """ Get the value of a header by its name. @@ -39,7 +57,7 @@ def get_header_value( default_value: str, optional The default value to return if the header is not found. Default is None. case_sensitive: bool, optional - Indicates whether the header name should be case-sensitive. Default is None. + Indicates whether the header name should be case-sensitive. Default is False. Returns ------- @@ -62,6 +80,22 @@ def get_header_value( ) +@overload +def get_query_string_value( + query_string_parameters: Dict[str, str] | None, + name: str, + default_value: str, +) -> str: ... + + +@overload +def get_query_string_value( + query_string_parameters: Dict[str, str] | None, + name: str, + default_value: str | None = None, +) -> str | None: ... + + def get_query_string_value( query_string_parameters: Dict[str, str] | None, name: str, diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py index f997d4b3f04..e2612e0334f 100644 --- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py +++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py @@ -47,6 +47,12 @@ def http_method(self) -> str: """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" return self["method"] + @overload + def get_query_string_value(self, name: str, default_value: str) -> str: ... + + @overload + def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ... + def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: """Get query string value by name diff --git a/examples/event_handler_graphql/src/custom_models.py b/examples/event_handler_graphql/src/custom_models.py index ae2f0180e15..61e03318d14 100644 --- a/examples/event_handler_graphql/src/custom_models.py +++ b/examples/event_handler_graphql/src/custom_models.py @@ -26,11 +26,11 @@ class Location(TypedDict, total=False): class MyCustomModel(AppSyncResolverEvent): @property def country_viewer(self) -> str: - return self.get_header_value(name="cloudfront-viewer-country", default_value="", case_sensitive=False) # type: ignore[return-value] # sentinel typing # noqa: E501 + return self.get_header_value(name="cloudfront-viewer-country", default_value="", case_sensitive=False) @property def api_key(self) -> str: - return self.get_header_value(name="x-api-key", default_value="", case_sensitive=False) # type: ignore[return-value] # sentinel typing # noqa: E501 + return self.get_header_value(name="x-api-key", default_value="", case_sensitive=False) @app.resolver(type_name="Query", field_name="listLocations")