diff --git a/logfire-api/logfire_api/_internal/integrations/starlette.pyi b/logfire-api/logfire_api/_internal/integrations/starlette.pyi index bc93bdd3b..5270d92b0 100644 --- a/logfire-api/logfire_api/_internal/integrations/starlette.pyi +++ b/logfire-api/logfire_api/_internal/integrations/starlette.pyi @@ -3,14 +3,9 @@ from logfire._internal.integrations.asgi import tweak_asgi_spans_tracer_provider from logfire._internal.utils import maybe_capture_server_headers as maybe_capture_server_headers from opentelemetry.instrumentation.asgi.types import ClientRequestHook, ClientResponseHook, ServerRequestHook from starlette.applications import Starlette -from typing_extensions import TypedDict, Unpack +from typing import Any -class StarletteInstrumentKwargs(TypedDict, total=False): - server_request_hook: ServerRequestHook | None - client_request_hook: ClientRequestHook | None - client_response_hook: ClientResponseHook | None - -def instrument_starlette(logfire_instance: Logfire, app: Starlette, *, record_send_receive: bool = False, capture_headers: bool = False, **kwargs: Unpack[StarletteInstrumentKwargs]): +def instrument_starlette(logfire_instance: Logfire, app: Starlette, *, record_send_receive: bool = False, capture_headers: bool = False, server_request_hook: ServerRequestHook | None = None, client_request_hook: ClientRequestHook | None = None, client_response_hook: ClientResponseHook | None = None, **kwargs: Any): """Instrument `app` so that spans are automatically created for each request. See the `Logfire.instrument_starlette` method for details. diff --git a/logfire-api/logfire_api/_internal/json_encoder.pyi b/logfire-api/logfire_api/_internal/json_encoder.pyi index 491e1704e..e694f4fcd 100644 --- a/logfire-api/logfire_api/_internal/json_encoder.pyi +++ b/logfire-api/logfire_api/_internal/json_encoder.pyi @@ -9,4 +9,4 @@ def encoder_by_type() -> dict[type[Any], EncoderFunction]: ... def to_json_value(o: Any, seen: set[int]) -> JsonValue: ... def logfire_json_dumps(obj: Any) -> str: ... def is_sqlalchemy(obj: Any) -> bool: ... -def is_attrs(obj: Any) -> bool: ... +def is_attrs(cls) -> bool: ... diff --git a/logfire-api/logfire_api/_internal/main.pyi b/logfire-api/logfire_api/_internal/main.pyi index ffa30bd90..d106df7b2 100644 --- a/logfire-api/logfire_api/_internal/main.pyi +++ b/logfire-api/logfire_api/_internal/main.pyi @@ -21,7 +21,6 @@ from .integrations.pymongo import PymongoInstrumentKwargs as PymongoInstrumentKw from .integrations.redis import RedisInstrumentKwargs as RedisInstrumentKwargs from .integrations.sqlalchemy import SQLAlchemyInstrumentKwargs as SQLAlchemyInstrumentKwargs from .integrations.sqlite3 import SQLite3Connection as SQLite3Connection, SQLite3InstrumentKwargs as SQLite3InstrumentKwargs -from .integrations.starlette import StarletteInstrumentKwargs as StarletteInstrumentKwargs from .integrations.system_metrics import Base as SystemMetricsBase, Config as SystemMetricsConfig from .integrations.wsgi import WSGIInstrumentKwargs as WSGIInstrumentKwargs from .json_encoder import logfire_json_dumps as logfire_json_dumps @@ -34,6 +33,7 @@ from django.http import HttpRequest as HttpRequest, HttpResponse as HttpResponse from fastapi import FastAPI from flask.app import Flask from opentelemetry.context import Context as Context +from opentelemetry.instrumentation.asgi.types import ClientRequestHook, ClientResponseHook, ServerRequestHook from opentelemetry.metrics import CallbackT as CallbackT, Counter, Histogram, UpDownCounter, _Gauge as Gauge from opentelemetry.sdk.trace import ReadableSpan, Span from opentelemetry.trace import SpanContext, Tracer @@ -643,7 +643,7 @@ class Logfire: response_hook: A function called right before a span is finished for the response. **kwargs: Additional keyword arguments to pass to the OpenTelemetry Flask instrumentation. """ - def instrument_starlette(self, app: Starlette, *, capture_headers: bool = False, record_send_receive: bool = False, **kwargs: Unpack[StarletteInstrumentKwargs]) -> None: + def instrument_starlette(self, app: Starlette, *, capture_headers: bool = False, record_send_receive: bool = False, server_request_hook: ServerRequestHook | None = None, client_request_hook: ClientRequestHook | None = None, client_response_hook: ClientResponseHook | None = None, **kwargs: Any) -> None: """Instrument `app` so that spans are automatically created for each request. Uses the @@ -658,6 +658,9 @@ class Logfire: These are disabled by default to reduce overhead and the number of spans created, since many can be created for a single request, and they are not often useful. If enabled, they will be set to debug level, meaning they will usually still be hidden in the UI. + server_request_hook: A function that receives a server span and the ASGI scope for every incoming request. + client_request_hook: A function that receives a span, the ASGI scope and the receive ASGI message for every ASGI receive event. + client_response_hook: A function that receives a span, the ASGI scope and the send ASGI message for every ASGI send event. **kwargs: Additional keyword arguments to pass to the OpenTelemetry Starlette instrumentation. """ def instrument_asgi(self, app: ASGIApp, capture_headers: bool = False, record_send_receive: bool = False, **kwargs: Unpack[ASGIInstrumentKwargs]) -> ASGIApp: diff --git a/logfire/_internal/integrations/starlette.py b/logfire/_internal/integrations/starlette.py index 9ca69212c..5c7bcccfd 100644 --- a/logfire/_internal/integrations/starlette.py +++ b/logfire/_internal/integrations/starlette.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Any from starlette.applications import Starlette try: + from opentelemetry.instrumentation.asgi.types import ClientRequestHook, ClientResponseHook, ServerRequestHook from opentelemetry.instrumentation.starlette import StarletteInstrumentor except ImportError: raise RuntimeError( @@ -17,15 +18,6 @@ from logfire._internal.integrations.asgi import tweak_asgi_spans_tracer_provider from logfire._internal.utils import maybe_capture_server_headers -if TYPE_CHECKING: - from opentelemetry.instrumentation.asgi.types import ClientRequestHook, ClientResponseHook, ServerRequestHook - from typing_extensions import TypedDict, Unpack - - class StarletteInstrumentKwargs(TypedDict, total=False): - server_request_hook: ServerRequestHook | None - client_request_hook: ClientRequestHook | None - client_response_hook: ClientResponseHook | None - def instrument_starlette( logfire_instance: Logfire, @@ -33,7 +25,10 @@ def instrument_starlette( *, record_send_receive: bool = False, capture_headers: bool = False, - **kwargs: Unpack[StarletteInstrumentKwargs], + server_request_hook: ServerRequestHook | None = None, + client_request_hook: ClientRequestHook | None = None, + client_response_hook: ClientResponseHook | None = None, + **kwargs: Any, ): """Instrument `app` so that spans are automatically created for each request. @@ -42,6 +37,9 @@ def instrument_starlette( maybe_capture_server_headers(capture_headers) StarletteInstrumentor().instrument_app( app, + server_request_hook=server_request_hook, + client_request_hook=client_request_hook, + client_response_hook=client_response_hook, **{ # type: ignore 'tracer_provider': tweak_asgi_spans_tracer_provider(logfire_instance, record_send_receive), 'meter_provider': logfire_instance.config.get_meter_provider(), diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index abed9b18d..beba2070c 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -74,6 +74,7 @@ from django.http import HttpRequest, HttpResponse from fastapi import FastAPI from flask.app import Flask + from opentelemetry.instrumentation.asgi.types import ClientRequestHook, ClientResponseHook, ServerRequestHook from opentelemetry.metrics import _Gauge as Gauge from sqlalchemy import Engine from sqlalchemy.ext.asyncio import AsyncEngine @@ -100,7 +101,6 @@ from .integrations.redis import RedisInstrumentKwargs from .integrations.sqlalchemy import SQLAlchemyInstrumentKwargs from .integrations.sqlite3 import SQLite3Connection - from .integrations.starlette import StarletteInstrumentKwargs from .integrations.system_metrics import Base as SystemMetricsBase, Config as SystemMetricsConfig from .utils import SysExcInfo @@ -1452,7 +1452,10 @@ def instrument_starlette( *, capture_headers: bool = False, record_send_receive: bool = False, - **kwargs: Unpack[StarletteInstrumentKwargs], + server_request_hook: ServerRequestHook | None = None, + client_request_hook: ClientRequestHook | None = None, + client_response_hook: ClientResponseHook | None = None, + **kwargs: Any, ) -> None: """Instrument `app` so that spans are automatically created for each request. @@ -1468,6 +1471,9 @@ def instrument_starlette( These are disabled by default to reduce overhead and the number of spans created, since many can be created for a single request, and they are not often useful. If enabled, they will be set to debug level, meaning they will usually still be hidden in the UI. + server_request_hook: A function that receives a server span and the ASGI scope for every incoming request. + client_request_hook: A function that receives a span, the ASGI scope and the receive ASGI message for every ASGI receive event. + client_response_hook: A function that receives a span, the ASGI scope and the send ASGI message for every ASGI send event. **kwargs: Additional keyword arguments to pass to the OpenTelemetry Starlette instrumentation. """ from .integrations.starlette import instrument_starlette @@ -1478,6 +1484,9 @@ def instrument_starlette( app, record_send_receive=record_send_receive, capture_headers=capture_headers, + server_request_hook=server_request_hook, + client_request_hook=client_request_hook, + client_response_hook=client_response_hook, **kwargs, )