Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(typing): improve overloads to ensure the return type follows the default_value type #4114

Merged
merged 5 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,17 @@ def path_parameters(self) -> Optional[Dict[str, str]]:
def stage_variables(self) -> Optional[Dict[str, str]]:
return self.get("stageVariables")

@overload
def get_header_value(self, name: str, default_value: str, case_sensitive: bool = False) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: Optional[bool] = False,
) -> Optional[str]: ...

def get_header_value(
self,
name: str,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, overload

from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
Expand Down Expand Up @@ -214,6 +214,22 @@ def stash(self) -> Optional[dict]:
a pipeline resolver."""
return self.get("stash")

@overload
def get_header_value(
self,
name: str,
default_value: str,
case_sensitive: Optional[bool] = False,
) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: Optional[bool] = False,
) -> Optional[str]: ...

def get_header_value(
self,
name: str,
Expand Down
6 changes: 6 additions & 0 deletions aws_lambda_powertools/utilities/data_classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,12 @@ def http_method(self) -> str:
"""The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT."""
return self["httpMethod"]

@overload
def get_query_string_value(self, name: str, default_value: str) -> str: ...

@overload
def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ...

def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]:
"""Get query string value by name

Expand Down
20 changes: 18 additions & 2 deletions aws_lambda_powertools/utilities/data_classes/kafka_event.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import base64
from functools import cached_property
from typing import Any, Dict, Iterator, List, Optional
from typing import Any, Dict, Iterator, List, Optional, overload

from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
Expand Down Expand Up @@ -69,10 +69,26 @@ def decoded_headers(self) -> Dict[str, bytes]:
"""Decodes the headers as a single dictionary."""
return {k: bytes(v) for chunk in self.headers for k, v in chunk.items()}

@overload
def get_header_value(
self,
name: str,
default_value: Optional[Any] = None,
default_value: str,
case_sensitive: bool = True,
) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: bool = True,
) -> Optional[str]: ...

def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: bool = True,
) -> Optional[str]:
"""Get a decoded header value by name."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Dict, Optional, overload

from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
Expand Down Expand Up @@ -73,6 +73,22 @@ def headers(self) -> Dict[str, str]:
The case of the original headers is retained in this map."""
return self["headers"]

@overload
def get_header_value(
self,
name: str,
default_value: str,
case_sensitive: Optional[bool] = False,
) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: Optional[bool] = False,
) -> Optional[str]: ...

def get_header_value(
self,
name: str,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import base64
from typing import Any, Dict
from typing import Any, Dict, overload


def base64_decode(value: str) -> str:
Expand All @@ -21,11 +21,29 @@ def base64_decode(value: str) -> str:
return base64.b64decode(value).decode("UTF-8")


@overload
def get_header_value(
headers: dict[str, Any],
name: str,
default_value: str | None,
case_sensitive: bool | None,
default_value: str,
case_sensitive: bool | None = False,
) -> str: ...


@overload
def get_header_value(
headers: dict[str, Any],
name: str,
default_value: str | None = None,
case_sensitive: bool | None = False,
) -> str | None: ...


def get_header_value(
headers: dict[str, Any],
name: str,
default_value: str | None = None,
case_sensitive: bool | None = False,
) -> str | None:
"""
Get the value of a header by its name.
Expand All @@ -39,7 +57,7 @@ def get_header_value(
default_value: str, optional
The default value to return if the header is not found. Default is None.
case_sensitive: bool, optional
Indicates whether the header name should be case-sensitive. Default is None.
Indicates whether the header name should be case-sensitive. Default is False.

Returns
-------
Expand All @@ -62,6 +80,22 @@ def get_header_value(
)


@overload
def get_query_string_value(
query_string_parameters: Dict[str, str] | None,
name: str,
default_value: str,
) -> str: ...


@overload
def get_query_string_value(
query_string_parameters: Dict[str, str] | None,
name: str,
default_value: str | None = None,
) -> str | None: ...


def get_query_string_value(
query_string_parameters: Dict[str, str] | None,
name: str,
Expand Down
6 changes: 6 additions & 0 deletions aws_lambda_powertools/utilities/data_classes/vpc_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def http_method(self) -> str:
"""The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT."""
return self["method"]

@overload
def get_query_string_value(self, name: str, default_value: str) -> str: ...

@overload
def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ...

def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]:
"""Get query string value by name

Expand Down
4 changes: 2 additions & 2 deletions examples/event_handler_graphql/src/custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ class Location(TypedDict, total=False):
class MyCustomModel(AppSyncResolverEvent):
@property
def country_viewer(self) -> str:
return self.get_header_value(name="cloudfront-viewer-country", default_value="", case_sensitive=False) # type: ignore[return-value] # sentinel typing # noqa: E501
return self.get_header_value(name="cloudfront-viewer-country", default_value="", case_sensitive=False)

@property
def api_key(self) -> str:
return self.get_header_value(name="x-api-key", default_value="", case_sensitive=False) # type: ignore[return-value] # sentinel typing # noqa: E501
return self.get_header_value(name="x-api-key", default_value="", case_sensitive=False)


@app.resolver(type_name="Query", field_name="listLocations")
Expand Down