diff --git a/src/programming/Ligare/programming/patterns/dependency_injection.py b/src/programming/Ligare/programming/patterns/dependency_injection.py index 0b555a5d..2a6be25b 100644 --- a/src/programming/Ligare/programming/patterns/dependency_injection.py +++ b/src/programming/Ligare/programming/patterns/dependency_injection.py @@ -76,7 +76,7 @@ def formatMessage(self, record: logging.LogRecord) -> dict[str, Any]: # pyright KeyError is raised if an unknown attribute is provided in the fmt_dict. """ return { - fmt_key: record.__dict__[fmt_val] + fmt_key: record.__dict__.get(fmt_val, None) for fmt_key, fmt_val in self.fmt_dict.items() } @@ -118,22 +118,24 @@ def __init__( name: str | None = None, log_level: int | str = logging.INFO, log_to_stdout: bool = False, + formatter: JSONFormatter | None = None, ) -> None: super().__init__(name, log_level, log_to_stdout) - formatter = JSONFormatter({ - "level": "levelname", - "message": "message", - "file": "pathname", - "func": "funcName", - "line": "lineno", - "loggerName": "name", - "processName": "processName", - "processID": "process", - "threadName": "threadName", - "threadID": "thread", - "timestamp": "asctime", - }) + if not formatter: + formatter = JSONFormatter({ + "level": "levelname", + "message": "message", + "file": "pathname", + "func": "funcName", + "line": "lineno", + "loggerName": "name", + "processName": "processName", + "processID": "process", + "threadName": "threadName", + "threadID": "thread", + "timestamp": "asctime", + }) handler = logging.StreamHandler() handler.formatter = formatter diff --git a/src/web/Ligare/web/middleware/consts.py b/src/web/Ligare/web/middleware/consts.py index d516dfd1..8374f3bd 100644 --- a/src/web/Ligare/web/middleware/consts.py +++ b/src/web/Ligare/web/middleware/consts.py @@ -2,7 +2,7 @@ String constants used by :ref:`Ligare.web`. """ -CORRELATION_ID_HEADER = "X-Correlation-Id" +REQUEST_ID_HEADER = "X-Correlation-Id" REQUEST_COOKIE_HEADER = "Cookie" RESPONSE_COOKIE_HEADER = "Set-Cookie" CORS_ACCESS_CONTROL_ALLOW_ORIGIN_HEADER = "Access-Control-Allow-Origin" diff --git a/src/web/Ligare/web/middleware/dependency_injection.py b/src/web/Ligare/web/middleware/dependency_injection.py index e26cb0e4..4c58702b 100644 --- a/src/web/Ligare/web/middleware/dependency_injection.py +++ b/src/web/Ligare/web/middleware/dependency_injection.py @@ -15,16 +15,68 @@ from injector import Binder, Injector, Module from Ligare.programming.dependency_injection import ConfigModule from Ligare.programming.patterns.dependency_injection import ( + JSONFormatter, JSONLoggerModule, LoggerModule, ) from Ligare.web.application import Config as AppConfig +from Ligare.web.middleware.openapi import ( + CorrelationIdMiddleware, + RequestIdMiddleware, + get_trace_id, +) from starlette.types import ASGIApp, Receive, Scope, Send from typing_extensions import override from . import RegisterMiddlewareCallback, TFlaskApp +class WebJSONFormatter(JSONFormatter): + @override + def __init__( + self, + fmt_dict: dict[str, str] | None = None, + time_format: str = "%Y-%m-%dT%H:%M:%S", + msec_format: str = "%s.%03dZ", + ): + super().__init__(fmt_dict, time_format, msec_format) + + @override + def formatMessage(self, record: logging.LogRecord) -> dict[str, Any]: + format = super().formatMessage(record) + correlation_ids = get_trace_id() + format["correlationId"] = correlation_ids.CorrelationId + format["requestId"] = correlation_ids.RequestId + return format + + +class WebJSONLoggerModule(JSONLoggerModule): + @override + def __init__( + self, + name: str | None = None, + log_level: int | str = logging.INFO, + log_to_stdout: bool = False, + formatter: JSONFormatter | None = None, + ) -> None: + formatter = WebJSONFormatter({ + "level": "levelname", + "message": "message", + "file": "pathname", + "func": "funcName", + "line": "lineno", + "loggerName": "name", + "processName": "processName", + "processID": "process", + "threadName": "threadName", + "threadID": "thread", + "timestamp": "asctime", + "correlationId": "correlationId", + "requestId": "requestId", + }) + super().__init__(name, log_level, log_to_stdout, formatter) + + class MiddlewareRoutine(Protocol): def __call__( self, scope: Scope, receive: Receive, send: Send, *args: Any @@ -53,7 +105,7 @@ def configure(self, binder: Binder) -> None: app_config = binder.injector.get(AppConfig) log_level = app_config.logging.log_level.upper() if app_config.logging.format == "JSON": - binder.install(JSONLoggerModule(self._flask_app.name, log_level)) + binder.install(WebJSONLoggerModule(self._flask_app.name, log_level)) else: binder.install(LoggerModule(self._flask_app.name, log_level)) @@ -110,6 +162,8 @@ def configure_dependencies( if isinstance(app, FlaskApp): app.add_middleware(OpenAPIEndpointDependencyInjectionMiddleware(flask_injector)) + app.add_middleware(CorrelationIdMiddleware) + app.add_middleware(RequestIdMiddleware) # For every module registered, check if any are "middleware" type modules. # if they are, they need to be registered with the application. diff --git a/src/web/Ligare/web/middleware/flask/__init__.py b/src/web/Ligare/web/middleware/flask/__init__.py index a2d536b8..c0d067f2 100644 --- a/src/web/Ligare/web/middleware/flask/__init__.py +++ b/src/web/Ligare/web/middleware/flask/__init__.py @@ -17,7 +17,6 @@ from ...config import Config from ..consts import ( CONTENT_SECURITY_POLICY_HEADER, - CORRELATION_ID_HEADER, CORS_ACCESS_CONTROL_ALLOW_CREDENTIALS_HEADER, CORS_ACCESS_CONTROL_ALLOW_METHODS_HEADER, CORS_ACCESS_CONTROL_ALLOW_ORIGIN_HEADER, @@ -26,6 +25,7 @@ ORIGIN_HEADER, OUTGOING_RESPONSE_MESSAGE, REQUEST_COOKIE_HEADER, + REQUEST_ID_HEADER, RESPONSE_COOKIE_HEADER, ) @@ -81,7 +81,7 @@ def _get_correlation_id(log: Logger) -> str: def _get_correlation_id_from_headers(log: Logger) -> str: try: - correlation_id = request.headers.get(CORRELATION_ID_HEADER) + correlation_id = request.headers.get(REQUEST_ID_HEADER) if correlation_id: # validate format @@ -89,13 +89,13 @@ def _get_correlation_id_from_headers(log: Logger) -> str: else: correlation_id = str(uuid4()) log.info( - f'Generated new UUID "{correlation_id}" for {CORRELATION_ID_HEADER} request header.' + f'Generated new UUID "{correlation_id}" for {REQUEST_ID_HEADER} request header.' ) return correlation_id except ValueError as e: - log.warning(f"Badly formatted {CORRELATION_ID_HEADER} received in request.") + log.warning(f"Badly formatted {REQUEST_ID_HEADER} received in request.") raise e @@ -107,11 +107,11 @@ def _get_correlation_id_from_json_logging(log: Logger) -> str | None: _ = uuid.UUID(correlation_id) return correlation_id except ValueError as e: - log.warning(f"Badly formatted {CORRELATION_ID_HEADER} received in request.") + log.warning(f"Badly formatted {REQUEST_ID_HEADER} received in request.") raise e except Exception as e: log.debug( - f"Error received when getting {CORRELATION_ID_HEADER} header from `json_logging`. Possibly `json_logging` is not configured, and this is not an error.", + f"Error received when getting {REQUEST_ID_HEADER} header from `json_logging`. Possibly `json_logging` is not configured, and this is not an error.", exc_info=e, ) @@ -195,7 +195,7 @@ def _wrap_all_api_responses(response: Response, config: Config, log: Logger): config.web.security.cors.allow_methods ) - response.headers[CORRELATION_ID_HEADER] = correlation_id + response.headers[REQUEST_ID_HEADER] = correlation_id if config.web.security.csp: response.headers[CONTENT_SECURITY_POLICY_HEADER] = config.web.security.csp diff --git a/src/web/Ligare/web/middleware/openapi/__init__.py b/src/web/Ligare/web/middleware/openapi/__init__.py index 4ef282bd..406259c0 100644 --- a/src/web/Ligare/web/middleware/openapi/__init__.py +++ b/src/web/Ligare/web/middleware/openapi/__init__.py @@ -6,12 +6,21 @@ import uuid from collections.abc import Iterable from contextlib import ExitStack -from contextvars import Token +from contextvars import ContextVar, Token from logging import Logger -from typing import Any, Awaitable, Callable, Literal, TypeAlias, TypeVar, cast +from typing import ( + Any, + Awaitable, + Callable, + Literal, + NamedTuple, + NewType, + TypeAlias, + TypeVar, + cast, +) from uuid import uuid4 -import json_logging import starlette import starlette.datastructures import starlette.requests @@ -33,10 +42,10 @@ from ...config import Config from ..consts import ( CONTENT_SECURITY_POLICY_HEADER, - CORRELATION_ID_HEADER, INCOMING_REQUEST_MESSAGE, OUTGOING_RESPONSE_MESSAGE, REQUEST_COOKIE_HEADER, + REQUEST_ID_HEADER, RESPONSE_COOKIE_HEADER, ) @@ -114,12 +123,154 @@ }, ) +CorrelationId = NewType("CorrelationId", str) +RequestId = NewType("RequestId", str) + +CORRELATION_ID_CTX_KEY = "correlationId" +REQUEST_ID_CTX_KEY = "requestId" + +_correlation_id_ctx_var: ContextVar[CorrelationId | None] = ContextVar( + CORRELATION_ID_CTX_KEY, default=None +) +_request_id_ctx_var: ContextVar[RequestId | None] = ContextVar( + REQUEST_ID_CTX_KEY, default=None +) + + +class TraceId(NamedTuple): + CorrelationId: CorrelationId | None + RequestId: RequestId | None + + +def get_trace_id() -> TraceId: + return TraceId(_correlation_id_ctx_var.get(), _request_id_ctx_var.get()) + + +@final +class CorrelationIdMiddleware: + """ + Generate a Correlation ID for each request. + + https://github.com/encode/starlette/issues/420 + """ + + def __init__( + self, + app: ASGIApp, + ) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ["http", "websocket"]: + await self.app(scope, receive, send) + return + + correlation_id = _correlation_id_ctx_var.set(CorrelationId(str(uuid4()))) + + await self.app(scope, receive, send) + + _correlation_id_ctx_var.reset(correlation_id) + + +@final +class RequestIdMiddleware: + """ + Generate a Trace ID for each request. + If X-Correlation-Id is set in the request headers, that ID is used instead. + """ + + _app: ASGIApp + + def __init__(self, app: ASGIApp): + super().__init__() + self._app = app + + @inject + async def __call__( + self, scope: Scope, receive: Receive, send: Send, log: Logger + ) -> None: + if scope["type"] not in ["http", "websocket"]: + return await self._app(scope, receive, send) + + # extract the request ID from the request headers if it is set + + request = cast(MiddlewareRequestDict, scope) + request_headers = request.get("headers") + + content_type = utils.extract_content_type(request_headers) + _, encoding = utils.split_content_type(content_type) + if encoding is None: + encoding = "utf-8" + + try: + request_id_header_encoded = REQUEST_ID_HEADER.lower().encode(encoding) + + request_id: bytes | None = next( + ( + request_id + for (header, request_id) in request_headers + if header == request_id_header_encoded + ), + None, + ) + + if request_id: + # validate format + request_id_decoded = request_id.decode(encoding) + _ = uuid.UUID(request_id_decoded) + request_id_token = _request_id_ctx_var.set( + RequestId(request_id_decoded) + ) + else: + request_id_decoded = str(uuid4()) + request_id = request_id_decoded.encode(encoding) + request_headers.append(( + request_id_header_encoded, + request_id, + )) + request_id_token = _request_id_ctx_var.set( + RequestId(request_id_decoded) + ) + log.info( + f'Generated new UUID "{request_id}" for {REQUEST_ID_HEADER} request header.' + ) + except ValueError as e: + log.warning(f"Badly formatted {REQUEST_ID_HEADER} received in request.") + raise e + + async def wrapped_send(message: Any) -> None: + nonlocal scope + nonlocal send + + if message["type"] != "http.response.start": + return await send(message) + + # include the request ID in response headers + + response = cast(MiddlewareResponseDict, message) + response_headers = response["headers"] + + content_type = utils.extract_content_type(response_headers) + _, encoding = utils.split_content_type(content_type) + if encoding is None: + encoding = "utf-8" + + response_headers.append(( + request_id_header_encoded, + request_id, + )) + + return await send(message) + + await self._app(scope, receive, wrapped_send) + + _request_id_ctx_var.reset(request_id_token) + def _get_correlation_id( request: MiddlewareRequestDict, response: MiddlewareResponseDict, log: Logger ) -> str: - correlation_id = _get_correlation_id_from_json_logging(response, log) - + correlation_id = get_trace_id().CorrelationId if not correlation_id: correlation_id = _get_correlation_id_from_headers(request, response, log) @@ -131,11 +282,11 @@ def _get_correlation_id_from_headers( ) -> str: try: headers = _headers_as_dict(request) - correlation_id = headers.get(CORRELATION_ID_HEADER.lower()) + correlation_id = headers.get(REQUEST_ID_HEADER.lower()) if not correlation_id: headers = _headers_as_dict(response) - correlation_id = headers.get(CORRELATION_ID_HEADER.lower()) + correlation_id = headers.get(REQUEST_ID_HEADER.lower()) if correlation_id: # validate format @@ -143,35 +294,16 @@ def _get_correlation_id_from_headers( else: correlation_id = str(uuid4()) log.info( - f'Generated new UUID "{correlation_id}" for {CORRELATION_ID_HEADER} request header.' + f'Generated new UUID "{correlation_id}" for {REQUEST_ID_HEADER} request header.' ) return correlation_id except ValueError as e: - log.warning(f"Badly formatted {CORRELATION_ID_HEADER} received in request.") + log.warning(f"Badly formatted {REQUEST_ID_HEADER} received in request.") raise e -def _get_correlation_id_from_json_logging( - request_response: MiddlewareRequestDict | MiddlewareResponseDict, log: Logger -) -> str | None: - correlation_id: None | str - try: - correlation_id = json_logging.get_correlation_id(request_response) - # validate format - _ = uuid.UUID(correlation_id) - return correlation_id - except ValueError as e: - log.warning(f"Badly formatted {CORRELATION_ID_HEADER} received in request.") - raise e - except Exception as e: - log.debug( - f"Error received when getting {CORRELATION_ID_HEADER} header from `json_logging`. Possibly `json_logging` is not configured, and this is not an error.", - exc_info=e, - ) - - def _headers_as_dict( request_response: MiddlewareRequestDict | MiddlewareResponseDict, ): @@ -179,7 +311,6 @@ def _headers_as_dict( isinstance(request_response, dict) # pyright: ignore[reportUnnecessaryIsInstance] and "headers" in request_response.keys() ): - # FIXME does this work for a middleware _response_ as well? return { key: value for (key, value) in decode_headers(request_response["headers"]) } @@ -190,14 +321,13 @@ def _headers_as_dict( @inject def _log_all_api_requests( request: MiddlewareRequestDict, - response: MiddlewareResponseDict, app: Flask, config: Config, log: Logger, ): request_headers_safe: dict[str, str] = _headers_as_dict(request) - correlation_id = _get_correlation_id(request, response, log) + correlation_id = get_trace_id().CorrelationId if ( request_headers_safe.get(REQUEST_COOKIE_HEADER) @@ -235,16 +365,11 @@ def _log_all_api_requests( ) -def _wrap_all_api_responses( - request: MiddlewareRequestDict, - response: MiddlewareResponseDict, - config: Config, - log: Logger, -): - correlation_id = _get_correlation_id(request, response, log) +def _wrap_all_api_responses(response: MiddlewareResponseDict, config: Config): + correlation_id = get_trace_id().CorrelationId response_headers = _headers_as_dict(response) - response_headers[CORRELATION_ID_HEADER] = correlation_id + response_headers[REQUEST_ID_HEADER] = str(correlation_id) if config.web.security.csp: response_headers[CONTENT_SECURITY_POLICY_HEADER] = config.web.security.csp @@ -317,6 +442,7 @@ def encode_headers( headers.append((header.encode(encoding), value.encode(encoding))) +@final class RequestLoggerMiddleware: _app: ASGIApp @@ -342,16 +468,16 @@ async def wrapped_send(message: Any) -> None: if message["type"] != "http.response.start": return await send(message) - response = cast(MiddlewareResponseDict, scope) request = cast(MiddlewareRequestDict, scope) - _log_all_api_requests(request, response, app, config, log) + _log_all_api_requests(request, app, config, log) return await send(message) await self._app(scope, receive, wrapped_send) +@final class ResponseLoggerMiddleware: _app: ASGIApp @@ -375,83 +501,13 @@ async def wrapped_send(message: Any) -> None: response = cast(MiddlewareResponseDict, message) _log_all_api_responses(request, response, config, log) - _wrap_all_api_responses(request, response, config, log) + _wrap_all_api_responses(response, config) return await send(message) await self._app(scope, receive, wrapped_send) -class CorrelationIDMiddleware: - _app: ASGIApp - - def __init__(self, app: ASGIApp): - super().__init__() - self._app = app - - @inject - async def __call__( - self, scope: Scope, receive: Receive, send: Send, log: Logger - ) -> None: - async def wrapped_send(message: Any) -> None: - nonlocal scope - nonlocal send - - if message["type"] != "http.response.start": - return await send(message) - - request = cast(MiddlewareRequestDict, scope) - response = cast(MiddlewareResponseDict, message) - - response_headers = response["headers"] - content_type = utils.extract_content_type(response_headers) - _, encoding = utils.split_content_type(content_type) - if encoding is None: - encoding = "utf-8" - - request_headers = request["headers"] - try: - correlation_id_header_encoded = CORRELATION_ID_HEADER.lower().encode( - encoding - ) - - request_correlation_id: bytes | None = next( - ( - correlation_id - for (header, correlation_id) in request_headers - if header == correlation_id_header_encoded - ), - None, - ) - - if request_correlation_id: - # validate format - _ = uuid.UUID(request_correlation_id.decode(encoding)) - else: - request_correlation_id = str(uuid4()).encode(encoding) - request_headers.append(( - correlation_id_header_encoded, - request_correlation_id, - )) - log.info( - f'Generated new UUID "{request_correlation_id}" for {CORRELATION_ID_HEADER} request header.' - ) - - response_headers.append(( - correlation_id_header_encoded, - request_correlation_id, - )) - - return await send(message) - except ValueError as e: - log.warning( - f"Badly formatted {CORRELATION_ID_HEADER} received in request." - ) - raise e - - await self._app(scope, receive, wrapped_send) - - _DEFAULT_HOSTNAME = "localhost" _DEFAULT_PORT = 80 @@ -643,7 +699,6 @@ def register_openapi_api_request_handlers(app: FlaskApp): def register_openapi_api_response_handlers(app: FlaskApp): - app.add_middleware(CorrelationIDMiddleware) app.add_middleware(ResponseLoggerMiddleware)