Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(apigateway): add exception_handler support #898

Merged
merged 6 commits into from
Dec 16, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 48 additions & 13 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from enum import Enum
from functools import partial
from http import HTTPStatus
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.exceptions import ServiceError
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.shared.functions import resolve_truthy_env_var_choice
from aws_lambda_powertools.shared.json_encoder import Encoder
Expand Down Expand Up @@ -435,6 +435,7 @@ def __init__(
self._proxy_type = proxy_type
self._routes: List[Route] = []
self._route_keys: List[str] = []
self._exception_handlers: Dict[Union[int, Type], Callable] = {}
self._cors = cors
self._cors_enabled: bool = cors is not None
self._cors_methods: Set[str] = {"OPTIONS"}
Expand Down Expand Up @@ -596,6 +597,11 @@ def _not_found(self, method: str) -> ResponseBuilder:
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None))

# Allow for custom exception handlers
handler = self._exception_handlers.get(404)
if handler:
return ResponseBuilder(handler(NotFoundError()))

return ResponseBuilder(
Response(
status_code=HTTPStatus.NOT_FOUND.value,
Expand All @@ -609,16 +615,11 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
"""Actually call the matching route with any provided keyword arguments."""
try:
return ResponseBuilder(self._to_response(route.func(**args)), route)
except ServiceError as e:
return ResponseBuilder(
Response(
status_code=e.status_code,
content_type=content_types.APPLICATION_JSON,
body=self._json_dump({"statusCode": e.status_code, "message": e.msg}),
),
route,
)
except Exception:
except Exception as exc:
response_builder = self._call_exception_handler(exc, route)
if response_builder:
return response_builder

if self._debug:
# If the user has turned on debug mode,
# we'll let the original exception propagate so
Expand All @@ -628,8 +629,10 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
status_code=500,
content_type=content_types.TEXT_PLAIN,
body="".join(traceback.format_exc()),
)
),
route,
)

raise

def _to_response(self, result: Union[Dict, Response]) -> Response:
Expand Down Expand Up @@ -676,6 +679,38 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None

self.route(*route)(func)

def not_found(self, func: Callable):
return self.exception_handler(404)(func)

def exception_handler(self, exc_class_or_status_code: Union[int, Type[Exception]]):
def register_exception_handler(func: Callable):
self._exception_handlers[exc_class_or_status_code] = func

return register_exception_handler

def _lookup_exception_handler(self, exp: Exception) -> Optional[Callable]:
for cls in type(exp).__mro__:
if cls in self._exception_handlers:
return self._exception_handlers[cls]
return None

def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]:
handler = self._lookup_exception_handler(exp)
if handler:
return ResponseBuilder(handler(exp), route)

if isinstance(exp, ServiceError):
return ResponseBuilder(
Response(
status_code=exp.status_code,
content_type=content_types.APPLICATION_JSON,
body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}),
),
route,
)

return None


class Router(BaseRouter):
"""Router helper class to allow splitting ApiGatewayResolver into multiple files"""
Expand Down
75 changes: 74 additions & 1 deletion tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def patch_func():
def handler(event, context):
return app.resolve(event, context)

# Also check check the route configurations
# Also check the route configurations
routes = app._routes
assert len(routes) == 5
for route in routes:
Expand Down Expand Up @@ -1076,3 +1076,76 @@ def foo():

assert result["statusCode"] == 200
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON


def test_exception_handler():
# GIVEN a resolver with an exception handler defined for ValueError
app = ApiGatewayResolver()

@app.exception_handler(ValueError)
def handle_value_error(ex: ValueError):
print(f"request path is '{app.current_event.path}'")
return Response(
status_code=418,
content_type=content_types.TEXT_HTML,
body=str(ex),
)

@app.get("/my/path")
def get_lambda() -> Response:
raise ValueError("Foo!")

# WHEN calling the event handler
# AND a ValueError is raised
result = app(LOAD_GW_EVENT, {})

# THEN call the exception_handler
assert result["statusCode"] == 418
assert result["headers"]["Content-Type"] == content_types.TEXT_HTML
assert result["body"] == "Foo!"


def test_exception_handler_service_error():
# GIVEN
app = ApiGatewayResolver()

@app.exception_handler(ServiceError)
def service_error(ex: ServiceError):
print(ex.msg)
return Response(
status_code=ex.status_code,
content_type=content_types.APPLICATION_JSON,
body="CUSTOM ERROR FORMAT",
)

@app.get("/my/path")
def get_lambda() -> Response:
raise InternalServerError("Something sensitive")

# WHEN calling the event handler
# AND a ServiceError is raised
result = app(LOAD_GW_EVENT, {})

# THEN call the exception_handler
assert result["statusCode"] == 500
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
assert result["body"] == "CUSTOM ERROR FORMAT"


def test_exception_handler_not_found():
# GIVEN a resolver with an exception handler defined for a 404 not found
app = ApiGatewayResolver()

@app.not_found
def handle_not_found(exc: NotFoundError) -> Response:
assert isinstance(exc, NotFoundError)
return Response(status_code=404, content_type=content_types.TEXT_PLAIN, body="I am a teapot!")

# WHEN calling the event handler
# AND not route is found
result = app(LOAD_GW_EVENT, {})

# THEN call the exception_handler
assert result["statusCode"] == 404
assert result["headers"]["Content-Type"] == content_types.TEXT_PLAIN
assert result["body"] == "I am a teapot!"