Skip to content

Commit

Permalink
Add type hints to psycopg (#3067)
Browse files Browse the repository at this point in the history
* Add type hints to Psycopg

* fix tests

* fix

* Add psycopg.Connection to nitpick

* Add py.typed

* add psycopg to nitpick again

* add psycopg to nitpick again

* move py.typed to the right folder

---------

Co-authored-by: Riccardo Magliocchetti <riccardo.magliocchetti@gmail.com>
  • Loading branch information
Kludex and xrmx authored Jan 15, 2025
1 parent 52871b8 commit 07c97ea
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 34 deletions.
2 changes: 2 additions & 0 deletions docs/nitpick-exceptions.ini
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ py-class=
callable
Consumer
confluent_kafka.Message
psycopg.Connection
psycopg.AsyncConnection
ObjectProxy

any=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,27 +137,28 @@
---
"""

from __future__ import annotations

import logging
import typing
from typing import Collection
from typing import Any, Callable, Collection, TypeVar

import psycopg # pylint: disable=import-self
from psycopg import (
AsyncCursor as pg_async_cursor, # pylint: disable=import-self,no-name-in-module
)
from psycopg import (
Cursor as pg_cursor, # pylint: disable=no-name-in-module,import-self
)
from psycopg.sql import Composed # pylint: disable=no-name-in-module

from opentelemetry.instrumentation import dbapi
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.psycopg.package import _instruments
from opentelemetry.instrumentation.psycopg.version import __version__
from opentelemetry.trace import TracerProvider

_logger = logging.getLogger(__name__)
_OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory"

ConnectionT = TypeVar(
"ConnectionT", psycopg.Connection, psycopg.AsyncConnection
)
CursorT = TypeVar("CursorT", psycopg.Cursor, psycopg.AsyncCursor)


class PsycopgInstrumentor(BaseInstrumentor):
_CONNECTION_ATTRIBUTES = {
Expand All @@ -172,7 +173,7 @@ class PsycopgInstrumentor(BaseInstrumentor):
def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
def _instrument(self, **kwargs: Any):
"""Integrate with PostgreSQL Psycopg library.
Psycopg: http://initd.org/psycopg/
"""
Expand Down Expand Up @@ -223,7 +224,7 @@ def _instrument(self, **kwargs):
enable_attribute_commenter=enable_attribute_commenter,
)

def _uninstrument(self, **kwargs):
def _uninstrument(self, **kwargs: Any):
""" "Disable Psycopg instrumentation"""
dbapi.unwrap_connect(psycopg, "connect") # pylint: disable=no-member
dbapi.unwrap_connect(
Expand All @@ -237,7 +238,9 @@ def _uninstrument(self, **kwargs):

# TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
@staticmethod
def instrument_connection(connection, tracer_provider=None):
def instrument_connection(
connection: ConnectionT, tracer_provider: TracerProvider | None = None
) -> ConnectionT:
"""Enable instrumentation in a psycopg connection.
Args:
Expand Down Expand Up @@ -269,7 +272,7 @@ def instrument_connection(connection, tracer_provider=None):

# TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
@staticmethod
def uninstrument_connection(connection):
def uninstrument_connection(connection: ConnectionT) -> ConnectionT:
connection.cursor_factory = getattr(
connection, _OTEL_CURSOR_FACTORY_KEY, None
)
Expand All @@ -281,9 +284,9 @@ def uninstrument_connection(connection):
class DatabaseApiIntegration(dbapi.DatabaseApiIntegration):
def wrapped_connection(
self,
connect_method: typing.Callable[..., typing.Any],
args: typing.Tuple[typing.Any, typing.Any],
kwargs: typing.Dict[typing.Any, typing.Any],
connect_method: Callable[..., Any],
args: tuple[Any, Any],
kwargs: dict[Any, Any],
):
"""Add object proxy to connection object."""
base_cursor_factory = kwargs.pop("cursor_factory", None)
Expand All @@ -299,9 +302,9 @@ def wrapped_connection(
class DatabaseApiAsyncIntegration(dbapi.DatabaseApiIntegration):
async def wrapped_connection(
self,
connect_method: typing.Callable[..., typing.Any],
args: typing.Tuple[typing.Any, typing.Any],
kwargs: typing.Dict[typing.Any, typing.Any],
connect_method: Callable[..., Any],
args: tuple[Any, Any],
kwargs: dict[Any, Any],
):
"""Add object proxy to connection object."""
base_cursor_factory = kwargs.pop("cursor_factory", None)
Expand All @@ -317,7 +320,7 @@ async def wrapped_connection(


class CursorTracer(dbapi.CursorTracer):
def get_operation_name(self, cursor, args):
def get_operation_name(self, cursor: CursorT, args: list[Any]) -> str:
if not args:
return ""

Expand All @@ -332,7 +335,7 @@ def get_operation_name(self, cursor, args):

return ""

def get_statement(self, cursor, args):
def get_statement(self, cursor: CursorT, args: list[Any]) -> str:
if not args:
return ""

Expand All @@ -342,7 +345,11 @@ def get_statement(self, cursor, args):
return statement


def _new_cursor_factory(db_api=None, base_factory=None, tracer_provider=None):
def _new_cursor_factory(
db_api: DatabaseApiIntegration | None = None,
base_factory: type[psycopg.Cursor] | None = None,
tracer_provider: TracerProvider | None = None,
):
if not db_api:
db_api = DatabaseApiIntegration(
__name__,
Expand All @@ -352,21 +359,21 @@ def _new_cursor_factory(db_api=None, base_factory=None, tracer_provider=None):
tracer_provider=tracer_provider,
)

base_factory = base_factory or pg_cursor
base_factory = base_factory or psycopg.Cursor
_cursor_tracer = CursorTracer(db_api)

class TracedCursorFactory(base_factory):
def execute(self, *args, **kwargs):
def execute(self, *args: Any, **kwargs: Any):
return _cursor_tracer.traced_execution(
self, super().execute, *args, **kwargs
)

def executemany(self, *args, **kwargs):
def executemany(self, *args: Any, **kwargs: Any):
return _cursor_tracer.traced_execution(
self, super().executemany, *args, **kwargs
)

def callproc(self, *args, **kwargs):
def callproc(self, *args: Any, **kwargs: Any):
return _cursor_tracer.traced_execution(
self, super().callproc, *args, **kwargs
)
Expand All @@ -375,7 +382,9 @@ def callproc(self, *args, **kwargs):


def _new_cursor_async_factory(
db_api=None, base_factory=None, tracer_provider=None
db_api: DatabaseApiAsyncIntegration | None = None,
base_factory: type[psycopg.AsyncCursor] | None = None,
tracer_provider: TracerProvider | None = None,
):
if not db_api:
db_api = DatabaseApiAsyncIntegration(
Expand All @@ -385,21 +394,21 @@ def _new_cursor_async_factory(
version=__version__,
tracer_provider=tracer_provider,
)
base_factory = base_factory or pg_async_cursor
base_factory = base_factory or psycopg.AsyncCursor
_cursor_tracer = CursorTracer(db_api)

class TracedCursorAsyncFactory(base_factory):
async def execute(self, *args, **kwargs):
async def execute(self, *args: Any, **kwargs: Any):
return await _cursor_tracer.traced_execution(
self, super().execute, *args, **kwargs
)

async def executemany(self, *args, **kwargs):
async def executemany(self, *args: Any, **kwargs: Any):
return await _cursor_tracer.traced_execution(
self, super().executemany, *args, **kwargs
)

async def callproc(self, *args, **kwargs):
async def callproc(self, *args: Any, **kwargs: Any):
return await _cursor_tracer.traced_execution(
self, super().callproc, *args, **kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations


_instruments = ("psycopg >= 3.1.0",)
_instruments: tuple[str, ...] = ("psycopg >= 3.1.0",)
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,10 @@ class PostgresqlIntegrationTestMixin:
def setUp(self):
super().setUp()
self.cursor_mock = mock.patch(
"opentelemetry.instrumentation.psycopg.pg_cursor", MockCursor
"opentelemetry.instrumentation.psycopg.psycopg.Cursor", MockCursor
)
self.cursor_async_mock = mock.patch(
"opentelemetry.instrumentation.psycopg.pg_async_cursor",
"opentelemetry.instrumentation.psycopg.psycopg.AsyncCursor",
MockAsyncCursor,
)
self.connection_mock = mock.patch("psycopg.connect", MockConnection)
Expand Down

0 comments on commit 07c97ea

Please sign in to comment.