Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(event_handler): mark API operation as deprecated for OpenAPI documentation #5732

Merged
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
cfa9305
Add deprecated parameter with default to BaseRouter.get
tcysin Dec 12, 2024
e0440dc
Add parameter with default to BaseRouter.route
tcysin Dec 12, 2024
f6c9bbb
Pass deprecated param from .get() into .route()
tcysin Dec 12, 2024
9df0059
Add param and pass along for post, put, delete, patch, head
tcysin Dec 12, 2024
d8ca3ff
Add param and pass along for ApiGatewayRestResolver.route
tcysin Dec 12, 2024
9d2512b
Ditto for Route.__init__, use when creating operation metadata
tcysin Dec 12, 2024
669d655
Add param and pass along in ApiGatewayResolver.route
tcysin Dec 12, 2024
e183b51
Add param and pass along in Router.route, workaround for include_router
tcysin Dec 12, 2024
37313d5
Functional tests
tcysin Dec 12, 2024
0dcef6e
Formatting
tcysin Dec 12, 2024
213dbbb
Merge branch 'develop' into feat/mark-api-operation-as-deprecated
leandrodamascena Dec 15, 2024
67cf2c9
Refactor to use defaultdict
tcysin Dec 16, 2024
a557013
Move deprecated operation tests into separate test case
tcysin Dec 17, 2024
f74b384
Simplify test case
tcysin Dec 17, 2024
7d66193
Merge branch 'develop' into feat/mark-api-operation-as-deprecated
tcysin Dec 17, 2024
abd503e
Merge branch 'develop' into feat/mark-api-operation-as-deprecated
leandrodamascena Dec 19, 2024
3886600
Put 'deprecated' param before 'middlewares'
tcysin Dec 19, 2024
3375e1d
Remove workaround
tcysin Dec 19, 2024
6e9a5da
Add test case for deprecated POST operation
tcysin Dec 19, 2024
d3ead97
Add 'deprecated' param to BedrockAgentResolver methods
tcysin Dec 19, 2024
9968d68
Small changes + trigger pipeline
leandrodamascena Dec 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import warnings
import zlib
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from functools import partial
from http import HTTPStatus
Expand Down Expand Up @@ -310,6 +311,7 @@ def __init__(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
middlewares: list[Callable[..., Response]] | None = None,
deprecated: bool = False,
):
"""

Expand Down Expand Up @@ -350,6 +352,8 @@ def __init__(
Additional OpenAPI extensions as a dictionary.
middlewares: list[Callable[..., Response]] | None
The list of route middlewares to be called in order.
deprecated: bool
Whether or not to mark this route as deprecated in the OpenAPI schema
"""
self.method = method.upper()
self.path = "/" if path.strip() == "" else path
Expand All @@ -374,6 +378,7 @@ def __init__(
self.openapi_extensions = openapi_extensions
self.middlewares = middlewares or []
self.operation_id = operation_id or self._generate_operation_id()
self.deprecated = deprecated

# _middleware_stack_built is used to ensure the middleware stack is only built once.
self._middleware_stack_built = False
Expand Down Expand Up @@ -670,6 +675,10 @@ def _openapi_operation_metadata(self, operation_ids: set[str]) -> dict[str, Any]
operation_ids.add(self.operation_id)
operation["operationId"] = self.operation_id

# Mark as deprecated if necessary
if self.deprecated:
operation["deprecated"] = True

return operation

@staticmethod
Expand Down Expand Up @@ -925,6 +934,7 @@ def route(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
middlewares: list[Callable[..., Any]] | None = None,
deprecated: bool = False,
) -> Callable[[AnyCallableT], AnyCallableT]:
raise NotImplementedError()

Expand Down Expand Up @@ -985,6 +995,7 @@ def get(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
middlewares: list[Callable[..., Any]] | None = None,
deprecated: bool = False,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Get route decorator with GET `method`

Expand Down Expand Up @@ -1024,6 +1035,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
middlewares,
deprecated,
)

def post(
Expand All @@ -1042,6 +1054,7 @@ def post(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
middlewares: list[Callable[..., Any]] | None = None,
deprecated: bool = False,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Post route decorator with POST `method`

Expand Down Expand Up @@ -1082,6 +1095,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
middlewares,
deprecated,
)

def put(
Expand All @@ -1100,6 +1114,7 @@ def put(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
middlewares: list[Callable[..., Any]] | None = None,
deprecated: bool = False,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Put route decorator with PUT `method`

Expand Down Expand Up @@ -1140,6 +1155,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
middlewares,
deprecated,
)

def delete(
Expand All @@ -1158,6 +1174,7 @@ def delete(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
middlewares: list[Callable[..., Any]] | None = None,
deprecated: bool = False,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Delete route decorator with DELETE `method`

Expand Down Expand Up @@ -1197,6 +1214,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
middlewares,
deprecated,
)

def patch(
Expand All @@ -1215,6 +1233,7 @@ def patch(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
middlewares: list[Callable] | None = None,
deprecated: bool = False,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Patch route decorator with PATCH `method`

Expand Down Expand Up @@ -1257,6 +1276,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
middlewares,
deprecated,
)

def head(
Expand All @@ -1275,6 +1295,7 @@ def head(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
middlewares: list[Callable] | None = None,
deprecated: bool = False,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Head route decorator with HEAD `method`

Expand Down Expand Up @@ -1316,6 +1337,7 @@ def lambda_handler(event, context):
security,
openapi_extensions,
middlewares,
deprecated,
)

def _push_processed_stack_frame(self, frame: str):
Expand Down Expand Up @@ -1951,6 +1973,7 @@ def route(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
middlewares: list[Callable[..., Any]] | None = None,
deprecated: bool = False,
) -> Callable[[AnyCallableT], AnyCallableT]:
"""Route decorator includes parameter `method`"""

Expand Down Expand Up @@ -1979,6 +2002,7 @@ def register_resolver(func: AnyCallableT) -> AnyCallableT:
security,
openapi_extensions,
middlewares,
deprecated,
)

# The more specific route wins.
Expand Down Expand Up @@ -2425,12 +2449,16 @@ def include_router(self, router: Router, prefix: str | None = None) -> None:
# Middleware store the route without prefix, so we must not include prefix when grabbing
middlewares = router._routes_with_middleware.get(route)

# Workaround to support backward-compatible interface
new_route = new_route[:-1] # positional arguments until `middlewares` parameter
deprecated: bool = route[-1] # see route_key in Router.route

# Need to use "type: ignore" here since mypy does not like a named parameter after
# tuple expansion since may cause duplicate named parameters in the function signature.
# In this case this is not possible since the tuple expansion is from a hashable source
# and the `middlewares` list is a non-hashable structure so will never be included.
# Still need to ignore for mypy checks or will cause failures (false-positive)
self.route(*new_route, middlewares=middlewares)(func) # type: ignore
self.route(*new_route, deprecated=deprecated, middlewares=middlewares)(func) # type: ignore

@staticmethod
def _get_fields_from_routes(routes: Sequence[Route]) -> list[ModelField]:
Expand Down Expand Up @@ -2471,7 +2499,7 @@ class Router(BaseRouter):

def __init__(self):
self._routes: dict[tuple, Callable] = {}
self._routes_with_middleware: dict[tuple, list[Callable]] = {}
self._routes_with_middleware: defaultdict[tuple, list[Callable]] = defaultdict(list)
self.api_resolver: BaseRouter | None = None
self.context = {} # early init as customers might add context before event resolution
self._exception_handlers: dict[type, Callable] = {}
Expand All @@ -2493,6 +2521,7 @@ def route(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
middlewares: list[Callable[..., Any]] | None = None,
deprecated: bool = False,
) -> Callable[[AnyCallableT], AnyCallableT]:
def register_route(func: AnyCallableT) -> AnyCallableT:
# All dict keys needs to be hashable. So we'll need to do some conversions:
Expand All @@ -2517,17 +2546,13 @@ def register_route(func: AnyCallableT) -> AnyCallableT:
include_in_schema,
frozen_security,
fronzen_openapi_extensions,
deprecated,
)

# Collate Middleware for routes
if middlewares is not None:
for handler in middlewares:
if self._routes_with_middleware.get(route_key) is None:
self._routes_with_middleware[route_key] = [handler]
else:
self._routes_with_middleware[route_key].append(handler)
else:
self._routes_with_middleware[route_key] = []
self._routes_with_middleware[route_key].append(handler)

self._routes[route_key] = func

Expand Down Expand Up @@ -2599,6 +2624,7 @@ def route(
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
middlewares: list[Callable[..., Any]] | None = None,
deprecated: bool = False,
) -> Callable[[AnyCallableT], AnyCallableT]:
# NOTE: see #1552 for more context.
return super().route(
Expand All @@ -2617,6 +2643,7 @@ def route(
security,
openapi_extensions,
middlewares,
deprecated,
)

# Override _compile_regex to exclude trailing slashes for route resolution
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from aws_lambda_powertools.shared.functions import resolve_env_var_choice
from aws_lambda_powertools.warnings import PowertoolsUserWarning


if TYPE_CHECKING:
from aws_lambda_powertools.metrics.provider.cloudwatch_emf.types import CloudWatchEMFOutput
from aws_lambda_powertools.metrics.types import MetricNameUnitResolution
Expand Down Expand Up @@ -295,8 +294,6 @@ def add_dimension(self, name: str, value: str) -> None:

self.dimension_set[name] = value



def add_metadata(self, key: str, value: Any) -> None:
"""Adds high cardinal metadata for metrics object

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class APIGatewayWebSocketEventIdentity(BaseModel):
source_ip: IPvAnyNetwork = Field(alias="sourceIp")
user_agent: Optional[str] = Field(None, alias="userAgent")


class APIGatewayWebSocketEventRequestContextBase(BaseModel):
extended_request_id: str = Field(alias="extendedRequestId")
request_time: str = Field(alias="requestTime")
Expand Down
14 changes: 14 additions & 0 deletions tests/functional/event_handler/_pydantic/test_openapi_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def handler():
get = path.get
assert get.summary == "GET /"
assert get.operationId == "handler__get"
assert get.deprecated is None

assert get.responses is not None
assert 200 in get.responses.keys()
Expand Down Expand Up @@ -388,6 +389,19 @@ def handler(user: Annotated[User, Body(description="This is a user")]):
assert request_body.content[JSON_CONTENT_TYPE].schema_.description == "This is a user"


def test_openapi_with_deprecated_operations():
app = APIGatewayRestResolver()

@app.get("/", deprecated=True)
def handler():
raise NotImplementedError()

schema = app.get_openapi_schema()

get = schema.paths["/"].get
assert get.deprecated is True


def test_openapi_with_excluded_operations():
app = APIGatewayRestResolver()

Expand Down
18 changes: 9 additions & 9 deletions tests/unit/metrics/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
import warnings

import pytest

from aws_lambda_powertools.metrics import Metrics
from aws_lambda_powertools.metrics.functions import (
extract_cloudwatch_metric_resolution_value,
extract_cloudwatch_metric_unit_value,
Expand All @@ -10,17 +12,17 @@
MetricUnitError,
)
from aws_lambda_powertools.metrics.provider.cloudwatch_emf.metric_properties import MetricResolution, MetricUnit
from aws_lambda_powertools.metrics import Metrics
from aws_lambda_powertools.warnings import PowertoolsUserWarning


@pytest.fixture
def warning_catcher(monkeypatch):
caught_warnings = []

def custom_warn(message, category=None, stacklevel=1, source=None):
caught_warnings.append(PowertoolsUserWarning(message))

monkeypatch.setattr(warnings, 'warn', custom_warn)
monkeypatch.setattr(warnings, "warn", custom_warn)
return caught_warnings


Expand Down Expand Up @@ -78,13 +80,13 @@ def test_extract_valid_cloudwatch_metric_unit_value():

def test_add_dimension_overwrite_warning(warning_catcher):
"""
Adds a dimension and then tries to add another with the same name
but a different value. Verifies if the dimension is updated with
the new value and warning is issued when an existing dimension
Adds a dimension and then tries to add another with the same name
but a different value. Verifies if the dimension is updated with
the new value and warning is issued when an existing dimension
is overwritten.
"""
metrics = Metrics(namespace="TestNamespace")

# GIVEN default dimension
dimension_name = "test-dimension"
value1 = "test-value-1"
Expand All @@ -100,5 +102,3 @@ def test_add_dimension_overwrite_warning(warning_catcher):
# AND a warning should be issued with the exact message
expected_warning = f"Dimension '{dimension_name}' has already been added. The previous value will be overwritten."
assert any(str(w) == expected_warning for w in warning_catcher)


2 changes: 1 addition & 1 deletion tests/unit/parser/_pydantic/test_apigw_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,4 @@ def test_apigw_websocket_disconnect_event():

assert parsed_event.is_base64_encoded == raw_event["isBase64Encoded"]
assert parsed_event.headers == raw_event["headers"]
assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"]
assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"]