Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor TracedConnectionProxy #1097

Merged
merged 2 commits into from
Jun 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ def uninstrument_connection(connection):
Returns:
An uninstrumented connection.
"""
if isinstance(connection, wrapt.ObjectProxy):
return connection.__wrapped__
if isinstance(connection, _TracedConnectionProxy):
return connection._connection

_logger.warning("Connection is not instrumented")
return connection
Expand Down Expand Up @@ -300,28 +300,35 @@ def get_connection_attributes(self, connection):
self.span_attributes[SpanAttributes.NET_PEER_PORT] = port


class _TracedConnectionProxy:
pass


def get_traced_connection_proxy(
connection, db_api_integration, *args, **kwargs
):
# pylint: disable=abstract-method
class TracedConnectionProxy(wrapt.ObjectProxy):
# pylint: disable=unused-argument
def __init__(self, connection, *args, **kwargs):
wrapt.ObjectProxy.__init__(self, connection)
class TracedConnectionProxy(type(connection), _TracedConnectionProxy):
def __init__(self, connection):
self._connection = connection

def __getattr__(self, name):
return object.__getattribute__(
object.__getattribute__(self, "_connection"), name
)

def cursor(self, *args, **kwargs):
return get_traced_cursor_proxy(
self.__wrapped__.cursor(*args, **kwargs), db_api_integration
self._connection.cursor(*args, **kwargs), db_api_integration
)

def __enter__(self):
self.__wrapped__.__enter__()
return self

def __exit__(self, *args, **kwargs):
self.__wrapped__.__exit__(*args, **kwargs)
# For some reason this is necessary as trying to access the close
# method of self._connection via __getattr__ leads to unexplained
# errors.
def close(self):
self._connection.close()

return TracedConnectionProxy(connection, *args, **kwargs)
return TracedConnectionProxy(connection)


class CursorTracer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,14 @@ def test_callproc(self):

@mock.patch("opentelemetry.instrumentation.dbapi")
def test_wrap_connect(self, mock_dbapi):
dbapi.wrap_connect(self.tracer, mock_dbapi, "connect", "-")
dbapi.wrap_connect(self.tracer, MockConnectionEmpty(), "connect", "-")
connection = mock_dbapi.connect()
self.assertEqual(mock_dbapi.connect.call_count, 1)
self.assertIsInstance(connection.__wrapped__, mock.Mock)
self.assertIsInstance(connection._connection, mock.Mock)

@mock.patch("opentelemetry.instrumentation.dbapi")
def test_unwrap_connect(self, mock_dbapi):
dbapi.wrap_connect(self.tracer, mock_dbapi, "connect", "-")
dbapi.wrap_connect(self.tracer, MockConnectionEmpty(), "connect", "-")
connection = mock_dbapi.connect()
self.assertEqual(mock_dbapi.connect.call_count, 1)

Expand All @@ -279,19 +279,21 @@ def test_unwrap_connect(self, mock_dbapi):
self.assertIsInstance(connection, mock.Mock)

def test_instrument_connection(self):
connection = mock.Mock()
connection = MockConnectionEmpty()
# Avoid get_attributes failing because can't concatenate mock
# pylint: disable=attribute-defined-outside-init
connection.database = "-"
connection2 = dbapi.instrument_connection(self.tracer, connection, "-")
self.assertIs(connection2.__wrapped__, connection)
self.assertIs(connection2._connection, connection)

def test_uninstrument_connection(self):
connection = mock.Mock()
connection = MockConnectionEmpty()
# Set connection.database to avoid a failure because mock can't
# be concatenated
# pylint: disable=attribute-defined-outside-init
connection.database = "-"
connection2 = dbapi.instrument_connection(self.tracer, connection, "-")
self.assertIs(connection2.__wrapped__, connection)
self.assertIs(connection2._connection, connection)

connection3 = dbapi.uninstrument_connection(connection2)
self.assertIs(connection3, connection)
Expand All @@ -307,10 +309,12 @@ def mock_connect(*args, **kwargs):
server_host = kwargs.get("server_host")
server_port = kwargs.get("server_port")
user = kwargs.get("user")
return MockConnection(database, server_port, server_host, user)
return MockConnectionWithAttributes(
database, server_port, server_host, user
)


class MockConnection:
class MockConnectionWithAttributes:
def __init__(self, database, server_port, server_host, user):
self.database = database
self.server_port = server_port
Expand Down Expand Up @@ -343,3 +347,7 @@ def executemany(self, query, params=None, throw_exception=False):
def callproc(self, query, params=None, throw_exception=False):
if throw_exception:
raise Exception("Test Exception")


class MockConnectionEmpty:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock
from unittest.mock import Mock, patch

import mysql.connector

Expand All @@ -22,15 +22,24 @@
from opentelemetry.test.test_base import TestBase


def mock_connect(*args, **kwargs):
class MockConnection:
def cursor(self):
# pylint: disable=no-self-use
return Mock()

return MockConnection()


class TestMysqlIntegration(TestBase):
def tearDown(self):
super().tearDown()
with self.disable_logging():
MySQLInstrumentor().uninstrument()

@mock.patch("mysql.connector.connect")
@patch("mysql.connector.connect", new=mock_connect)
# pylint: disable=unused-argument
def test_instrumentor(self, mock_connect):
def test_instrumentor(self):
MySQLInstrumentor().instrument()

cnx = mysql.connector.connect(database="test")
Expand Down Expand Up @@ -58,9 +67,8 @@ def test_instrumentor(self, mock_connect):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

@mock.patch("mysql.connector.connect")
# pylint: disable=unused-argument
def test_custom_tracer_provider(self, mock_connect):
@patch("mysql.connector.connect", new=mock_connect)
def test_custom_tracer_provider(self):
resource = resources.Resource.create({})
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result
Expand All @@ -77,9 +85,9 @@ def test_custom_tracer_provider(self, mock_connect):

self.assertIs(span.resource, resource)

@mock.patch("mysql.connector.connect")
@patch("mysql.connector.connect", new=mock_connect)
# pylint: disable=unused-argument
def test_instrument_connection(self, mock_connect):
def test_instrument_connection(self):
cnx = mysql.connector.connect(database="test")
query = "SELECT * FROM test"
cursor = cnx.cursor()
Expand All @@ -95,9 +103,9 @@ def test_instrument_connection(self, mock_connect):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

@mock.patch("mysql.connector.connect")
@patch("mysql.connector.connect", new=mock_connect)
# pylint: disable=unused-argument
def test_uninstrument_connection(self, mock_connect):
def test_uninstrument_connection(self):
MySQLInstrumentor().instrument()
cnx = mysql.connector.connect(database="test")
query = "SELECT * FROM test"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock
from unittest.mock import Mock, patch

import pymysql

Expand All @@ -22,15 +22,24 @@
from opentelemetry.test.test_base import TestBase


def mock_connect(*args, **kwargs):
class MockConnection:
def cursor(self):
# pylint: disable=no-self-use
return Mock()

return MockConnection()


class TestPyMysqlIntegration(TestBase):
def tearDown(self):
super().tearDown()
with self.disable_logging():
PyMySQLInstrumentor().uninstrument()

@mock.patch("pymysql.connect")
@patch("pymysql.connect", new=mock_connect)
# pylint: disable=unused-argument
def test_instrumentor(self, mock_connect):
def test_instrumentor(self):
PyMySQLInstrumentor().instrument()

cnx = pymysql.connect(database="test")
Expand Down Expand Up @@ -58,9 +67,9 @@ def test_instrumentor(self, mock_connect):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

@mock.patch("pymysql.connect")
@patch("pymysql.connect", new=mock_connect)
# pylint: disable=unused-argument
def test_custom_tracer_provider(self, mock_connect):
def test_custom_tracer_provider(self):
resource = resources.Resource.create({})
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result
Expand All @@ -78,9 +87,9 @@ def test_custom_tracer_provider(self, mock_connect):

self.assertIs(span.resource, resource)

@mock.patch("pymysql.connect")
@patch("pymysql.connect", new=mock_connect)
# pylint: disable=unused-argument
def test_instrument_connection(self, mock_connect):
def test_instrument_connection(self):
cnx = pymysql.connect(database="test")
query = "SELECT * FROM test"
cursor = cnx.cursor()
Expand All @@ -96,9 +105,9 @@ def test_instrument_connection(self, mock_connect):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

@mock.patch("pymysql.connect")
@patch("pymysql.connect", new=mock_connect)
# pylint: disable=unused-argument
def test_uninstrument_connection(self, mock_connect):
def test_uninstrument_connection(self):
PyMySQLInstrumentor().instrument()
cnx = pymysql.connect(database="test")
query = "SELECT * FROM test"
Expand Down