From 0f5f3b3a51154c9c2716253f9480a53e4204e7dd Mon Sep 17 00:00:00 2001 From: Joel Tetrault Date: Wed, 19 Feb 2020 10:05:22 -0600 Subject: [PATCH] changes to make app.current_request thread safe when running chalice locally - fixes race conditions that can occur when chalice is being run locally and it handling multiple concurrent requests --- chalice/local.py | 44 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/chalice/local.py b/chalice/local.py index f6ce45102f..aa9f6bd3ff 100644 --- a/chalice/local.py +++ b/chalice/local.py @@ -16,7 +16,16 @@ from six.moves.BaseHTTPServer import HTTPServer from six.moves.BaseHTTPServer import BaseHTTPRequestHandler from six.moves.socketserver import ThreadingMixIn -from typing import List, Any, Dict, Tuple, Callable, Optional, Union # noqa +from typing import ( + List, + Any, + Dict, + Tuple, + Callable, + Optional, + Union, + cast, +) # noqa from chalice.app import Chalice # noqa from chalice.app import CORSConfig # noqa @@ -47,7 +56,9 @@ def time(self): def create_local_server(app_obj, config, host, port): # type: (Chalice, Config, str, int) -> LocalDevServer - return LocalDevServer(app_obj, config, host, port) + local_app_obj = LocalChalice(app_obj) + casted_local_app_obj = cast(Chalice, local_app_obj) + return LocalDevServer(casted_local_app_obj, config, host, port) class LocalARNBuilder(object): @@ -661,3 +672,32 @@ def shutdown(self): # type: () -> None if self._server is not None: self._server.shutdown() + + +class LocalChalice(object): + def __init__(self, chalice): + # type: (Chalice) -> None + self._current_request_lookup = {} # type: Dict[int, Optional[Request]] + self._chalice = chalice + + @property + def current_request(self): # noqa + # type: () -> Optional[Request] + thread_id = threading.current_thread().ident + assert thread_id is not None + return self._current_request_lookup.get(thread_id, None) + + @current_request.setter + def current_request(self, value): # noqa + # type: (Optional[Request]) -> None + thread_id = threading.current_thread().ident + assert thread_id is not None + self._current_request_lookup[thread_id] = value + + def __getattr__(self, name): + # type: (str) -> Any + return getattr(self._chalice, name) + + def __call__(self, *args, **kwargs): + # type: (Any, Any) -> Any + return self._chalice(*args, **kwargs)