From e09fb83ad88c02b913f4ceac51b47b4a3d198583 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Tue, 21 May 2024 19:18:11 +0100 Subject: [PATCH 1/4] Fix CORS when working with multiheaders --- .../event_handler/api_gateway.py | 30 +++++++++++++++++-- docs/index.md | 10 +++---- .../install/arm64/{cdk.py => cdk_arm64.py} | 0 .../arm64/{pulumi.py => pulumi_arm64.py} | 0 .../install/sar/{cdk.py => cdk_sar.py} | 0 .../install/x86_64/{cdk.py => cdk_x86.py} | 0 .../x86_64/{pulumi.py => pulumi_x86.py} | 0 mypy.ini | 6 ++++ tests/events/apiGatewayProxyEvent.json | 3 ++ 9 files changed, 42 insertions(+), 7 deletions(-) rename examples/homepage/install/arm64/{cdk.py => cdk_arm64.py} (100%) rename examples/homepage/install/arm64/{pulumi.py => pulumi_arm64.py} (100%) rename examples/homepage/install/sar/{cdk.py => cdk_sar.py} (100%) rename examples/homepage/install/x86_64/{cdk.py => cdk_x86.py} (100%) rename examples/homepage/install/x86_64/{pulumi.py => pulumi_x86.py} (100%) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 26da85679bc..9d177035bdd 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -217,6 +217,30 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]: headers["Access-Control-Allow-Credentials"] = "true" return headers + @staticmethod + def extract_origin_header(resolver_headers: Dict): + """ + Extracts the 'origin' or 'Origin' header from the provided resolver headers. + + The 'origin' or 'Origin' header can be either a single header or a multi-header. + + Args: + resolver_headers (Dict): A dictionary containing the headers. + + Returns: + Union[str, List[str], None]: The value(s) of the 'origin' or 'Origin' header. + If the header is a single header, a string is returned. + If the header is a multi-header, a list of strings is returned. + If the header is not present, None is returned. + """ + resolved_header = resolver_headers.get("origin") or resolver_headers.get("Origin") + if isinstance(resolved_header, str): + return resolved_header + if isinstance(resolved_header, list): + return resolved_header[0] + + return resolved_header + class Response(Generic[ResponseT]): """Response data class that provides greater control over what is returned from the proxy event""" @@ -782,7 +806,8 @@ def __init__( def _add_cors(self, event: ResponseEventT, cors: CORSConfig): """Update headers to include the configured Access-Control headers""" - self.response.headers.update(cors.to_dict(event.get_header_value("Origin"))) + extracted_origin_header = cors.extract_origin_header(event.resolved_headers_field) + self.response.headers.update(cors.to_dict(extracted_origin_header)) def _add_cache_control(self, cache_control: str): """Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used.""" @@ -2129,7 +2154,8 @@ def _not_found(self, method: str) -> ResponseBuilder: headers = {} if self._cors: logger.debug("CORS is enabled, updating headers.") - headers.update(self._cors.to_dict(self.current_event.get_header_value("Origin"))) + extracted_origin_header = self._cors.extract_origin_header(self.current_event.resolved_headers_field) + headers.update(self._cors.to_dict(extracted_origin_header)) if method == "OPTIONS": logger.debug("Pre-flight request detected. Returning CORS with null response") diff --git a/docs/index.md b/docs/index.md index 660bbe71c3f..f7d391d23de 100644 --- a/docs/index.md +++ b/docs/index.md @@ -89,7 +89,7 @@ You can install Powertools for AWS Lambda (Python) using your favorite dependenc === "CDK" ```python hl_lines="13 19" - --8<-- "examples/homepage/install/x86_64/cdk.py" + --8<-- "examples/homepage/install/x86_64/cdk_x86.py" ``` === "Terraform" @@ -101,7 +101,7 @@ You can install Powertools for AWS Lambda (Python) using your favorite dependenc === "Pulumi" ```python hl_lines="21-27" - --8<-- "examples/homepage/install/x86_64/pulumi.py" + --8<-- "examples/homepage/install/x86_64/pulumi_x86.py" ``` === "Amplify" @@ -127,7 +127,7 @@ You can install Powertools for AWS Lambda (Python) using your favorite dependenc === "CDK" ```python hl_lines="13 19" - --8<-- "examples/homepage/install/arm64/cdk.py" + --8<-- "examples/homepage/install/arm64/cdk_arm64.py" ``` === "Terraform" @@ -139,7 +139,7 @@ You can install Powertools for AWS Lambda (Python) using your favorite dependenc === "Pulumi" ```python hl_lines="21-27" - --8<-- "examples/homepage/install/arm64/pulumi.py" + --8<-- "examples/homepage/install/arm64/pulumi_arm64.py" ``` === "Amplify" @@ -275,7 +275,7 @@ Compared with the [public Layer ARN](#lambda-layer) option, SAR allows you to ch === "CDK" ```python hl_lines="7 16-20 23-27" - --8<-- "examples/homepage/install/sar/cdk.py" + --8<-- "examples/homepage/install/sar/cdk_sar.py" ``` === "Terraform" diff --git a/examples/homepage/install/arm64/cdk.py b/examples/homepage/install/arm64/cdk_arm64.py similarity index 100% rename from examples/homepage/install/arm64/cdk.py rename to examples/homepage/install/arm64/cdk_arm64.py diff --git a/examples/homepage/install/arm64/pulumi.py b/examples/homepage/install/arm64/pulumi_arm64.py similarity index 100% rename from examples/homepage/install/arm64/pulumi.py rename to examples/homepage/install/arm64/pulumi_arm64.py diff --git a/examples/homepage/install/sar/cdk.py b/examples/homepage/install/sar/cdk_sar.py similarity index 100% rename from examples/homepage/install/sar/cdk.py rename to examples/homepage/install/sar/cdk_sar.py diff --git a/examples/homepage/install/x86_64/cdk.py b/examples/homepage/install/x86_64/cdk_x86.py similarity index 100% rename from examples/homepage/install/x86_64/cdk.py rename to examples/homepage/install/x86_64/cdk_x86.py diff --git a/examples/homepage/install/x86_64/pulumi.py b/examples/homepage/install/x86_64/pulumi_x86.py similarity index 100% rename from examples/homepage/install/x86_64/pulumi.py rename to examples/homepage/install/x86_64/pulumi_x86.py diff --git a/mypy.ini b/mypy.ini index 5fcb1533707..3c5859c0bb0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -12,6 +12,12 @@ disable_error_code = annotation-unchecked [mypy-jmespath] ignore_missing_imports=True +[mypy-pulumi.*] +ignore_missing_imports=True + +[mypy-pulumi_aws.*] +ignore_missing_imports=True + [mypy-aws_encryption_sdk.*] ignore_missing_imports=True diff --git a/tests/events/apiGatewayProxyEvent.json b/tests/events/apiGatewayProxyEvent.json index da814c91100..07dd89c2673 100644 --- a/tests/events/apiGatewayProxyEvent.json +++ b/tests/events/apiGatewayProxyEvent.json @@ -12,6 +12,9 @@ "Header1": [ "value1" ], + "Origin": [ + "https://aws.amazon.com" + ], "Header2": [ "value1", "value2" From c2aa8722ddba1ce9866657165d0e9c5b1f5c837d Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Tue, 21 May 2024 21:48:10 +0100 Subject: [PATCH 2/4] Addressing feedback --- aws_lambda_powertools/event_handler/api_gateway.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 9d177035bdd..1428966f861 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -58,6 +58,7 @@ VPCLatticeEventV2, ) from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent +from aws_lambda_powertools.utilities.data_classes.shared_functions import get_header_value from aws_lambda_powertools.utilities.typing import LambdaContext logger = logging.getLogger(__name__) @@ -228,12 +229,9 @@ def extract_origin_header(resolver_headers: Dict): resolver_headers (Dict): A dictionary containing the headers. Returns: - Union[str, List[str], None]: The value(s) of the 'origin' or 'Origin' header. - If the header is a single header, a string is returned. - If the header is a multi-header, a list of strings is returned. - If the header is not present, None is returned. + Optional[str]: The value(s) of the origin header or None. """ - resolved_header = resolver_headers.get("origin") or resolver_headers.get("Origin") + resolved_header = get_header_value(resolver_headers, "origin", None, case_sensitive=False) if isinstance(resolved_header, str): return resolved_header if isinstance(resolved_header, list): From 05a243907dd5bee75a79602bfa77ef11ce229105 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 22 May 2024 10:56:41 +0100 Subject: [PATCH 3/4] Addressing feedback --- .../event_handler/api_gateway.py | 28 ++----------------- aws_lambda_powertools/event_handler/util.py | 26 +++++++++++++++++ 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 1428966f861..abbeadc5c41 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -43,7 +43,7 @@ validation_error_definition, validation_error_response_definition, ) -from aws_lambda_powertools.event_handler.util import _FrozenDict +from aws_lambda_powertools.event_handler.util import _FrozenDict, extract_origin_header from aws_lambda_powertools.shared.cookies import Cookie from aws_lambda_powertools.shared.functions import powertools_dev_is_set from aws_lambda_powertools.shared.json_encoder import Encoder @@ -58,7 +58,6 @@ VPCLatticeEventV2, ) from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent -from aws_lambda_powertools.utilities.data_classes.shared_functions import get_header_value from aws_lambda_powertools.utilities.typing import LambdaContext logger = logging.getLogger(__name__) @@ -218,27 +217,6 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]: headers["Access-Control-Allow-Credentials"] = "true" return headers - @staticmethod - def extract_origin_header(resolver_headers: Dict): - """ - Extracts the 'origin' or 'Origin' header from the provided resolver headers. - - The 'origin' or 'Origin' header can be either a single header or a multi-header. - - Args: - resolver_headers (Dict): A dictionary containing the headers. - - Returns: - Optional[str]: The value(s) of the origin header or None. - """ - resolved_header = get_header_value(resolver_headers, "origin", None, case_sensitive=False) - if isinstance(resolved_header, str): - return resolved_header - if isinstance(resolved_header, list): - return resolved_header[0] - - return resolved_header - class Response(Generic[ResponseT]): """Response data class that provides greater control over what is returned from the proxy event""" @@ -804,7 +782,7 @@ def __init__( def _add_cors(self, event: ResponseEventT, cors: CORSConfig): """Update headers to include the configured Access-Control headers""" - extracted_origin_header = cors.extract_origin_header(event.resolved_headers_field) + extracted_origin_header = extract_origin_header(event.resolved_headers_field) self.response.headers.update(cors.to_dict(extracted_origin_header)) def _add_cache_control(self, cache_control: str): @@ -2152,7 +2130,7 @@ def _not_found(self, method: str) -> ResponseBuilder: headers = {} if self._cors: logger.debug("CORS is enabled, updating headers.") - extracted_origin_header = self._cors.extract_origin_header(self.current_event.resolved_headers_field) + extracted_origin_header = extract_origin_header(self.current_event.resolved_headers_field) headers.update(self._cors.to_dict(extracted_origin_header)) if method == "OPTIONS": diff --git a/aws_lambda_powertools/event_handler/util.py b/aws_lambda_powertools/event_handler/util.py index 2832f8102ee..3a2aefd9c20 100644 --- a/aws_lambda_powertools/event_handler/util.py +++ b/aws_lambda_powertools/event_handler/util.py @@ -1,3 +1,8 @@ +from typing import Any, Dict + +from aws_lambda_powertools.utilities.data_classes.shared_functions import get_header_value + + class _FrozenDict(dict): """ A dictionary that can be used as a key in another dictionary. @@ -11,3 +16,24 @@ class _FrozenDict(dict): def __hash__(self): return hash(frozenset(self.keys())) + + +def extract_origin_header(resolver_headers: Dict[str, Any]): + """ + Extracts the 'origin' or 'Origin' header from the provided resolver headers. + + The 'origin' or 'Origin' header can be either a single header or a multi-header. + + Args: + resolver_headers (Dict): A dictionary containing the headers. + + Returns: + Optional[str]: The value(s) of the origin header or None. + """ + resolved_header = get_header_value(resolver_headers, "origin", None, case_sensitive=False) + if isinstance(resolved_header, str): + return resolved_header + if isinstance(resolved_header, list): + return resolved_header[0] + + return resolved_header From e62acb9bd6e107883ec0b64a63b3eb492a5fa0a1 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 22 May 2024 11:15:38 +0100 Subject: [PATCH 4/4] Addressing Heitor's feedback --- aws_lambda_powertools/event_handler/util.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/event_handler/util.py b/aws_lambda_powertools/event_handler/util.py index 3a2aefd9c20..6f2caf10858 100644 --- a/aws_lambda_powertools/event_handler/util.py +++ b/aws_lambda_powertools/event_handler/util.py @@ -30,9 +30,12 @@ def extract_origin_header(resolver_headers: Dict[str, Any]): Returns: Optional[str]: The value(s) of the origin header or None. """ - resolved_header = get_header_value(resolver_headers, "origin", None, case_sensitive=False) - if isinstance(resolved_header, str): - return resolved_header + resolved_header = get_header_value( + headers=resolver_headers, + name="origin", + default_value=None, + case_sensitive=False, + ) if isinstance(resolved_header, list): return resolved_header[0]