diff --git a/instrumentation/opentelemetry-instrumentation-dbapi/src/opentelemetry/instrumentation/dbapi/__init__.py b/instrumentation/opentelemetry-instrumentation-dbapi/src/opentelemetry/instrumentation/dbapi/__init__.py index 8f85349edf..2559466221 100644 --- a/instrumentation/opentelemetry-instrumentation-dbapi/src/opentelemetry/instrumentation/dbapi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-dbapi/src/opentelemetry/instrumentation/dbapi/__init__.py @@ -214,8 +214,8 @@ def uninstrument_connection(connection): Returns: An uninstrumented connection. """ - if isinstance(connection, wrapt.ObjectProxy): - return connection.__wrapped__ + if isinstance(connection, _TracedConnectionProxy): + return connection._connection _logger.warning("Connection is not instrumented") return connection @@ -300,28 +300,35 @@ def get_connection_attributes(self, connection): self.span_attributes[SpanAttributes.NET_PEER_PORT] = port +class _TracedConnectionProxy: + pass + + def get_traced_connection_proxy( connection, db_api_integration, *args, **kwargs ): # pylint: disable=abstract-method - class TracedConnectionProxy(wrapt.ObjectProxy): - # pylint: disable=unused-argument - def __init__(self, connection, *args, **kwargs): - wrapt.ObjectProxy.__init__(self, connection) + class TracedConnectionProxy(type(connection), _TracedConnectionProxy): + def __init__(self, connection): + self._connection = connection + + def __getattr__(self, name): + return object.__getattribute__( + object.__getattribute__(self, "_connection"), name + ) def cursor(self, *args, **kwargs): return get_traced_cursor_proxy( - self.__wrapped__.cursor(*args, **kwargs), db_api_integration + self._connection.cursor(*args, **kwargs), db_api_integration ) - def __enter__(self): - self.__wrapped__.__enter__() - return self - - def __exit__(self, *args, **kwargs): - self.__wrapped__.__exit__(*args, **kwargs) + # For some reason this is necessary as trying to access the close + # method of self._connection via __getattr__ leads to unexplained + # errors. + def close(self): + self._connection.close() - return TracedConnectionProxy(connection, *args, **kwargs) + return TracedConnectionProxy(connection) class CursorTracer: diff --git a/instrumentation/opentelemetry-instrumentation-dbapi/tests/test_dbapi_integration.py b/instrumentation/opentelemetry-instrumentation-dbapi/tests/test_dbapi_integration.py index 0302824db4..2644cb2fcd 100644 --- a/instrumentation/opentelemetry-instrumentation-dbapi/tests/test_dbapi_integration.py +++ b/instrumentation/opentelemetry-instrumentation-dbapi/tests/test_dbapi_integration.py @@ -262,14 +262,14 @@ def test_callproc(self): @mock.patch("opentelemetry.instrumentation.dbapi") def test_wrap_connect(self, mock_dbapi): - dbapi.wrap_connect(self.tracer, mock_dbapi, "connect", "-") + dbapi.wrap_connect(self.tracer, MockConnectionEmpty(), "connect", "-") connection = mock_dbapi.connect() self.assertEqual(mock_dbapi.connect.call_count, 1) - self.assertIsInstance(connection.__wrapped__, mock.Mock) + self.assertIsInstance(connection._connection, mock.Mock) @mock.patch("opentelemetry.instrumentation.dbapi") def test_unwrap_connect(self, mock_dbapi): - dbapi.wrap_connect(self.tracer, mock_dbapi, "connect", "-") + dbapi.wrap_connect(self.tracer, MockConnectionEmpty(), "connect", "-") connection = mock_dbapi.connect() self.assertEqual(mock_dbapi.connect.call_count, 1) @@ -279,19 +279,21 @@ def test_unwrap_connect(self, mock_dbapi): self.assertIsInstance(connection, mock.Mock) def test_instrument_connection(self): - connection = mock.Mock() + connection = MockConnectionEmpty() # Avoid get_attributes failing because can't concatenate mock + # pylint: disable=attribute-defined-outside-init connection.database = "-" connection2 = dbapi.instrument_connection(self.tracer, connection, "-") - self.assertIs(connection2.__wrapped__, connection) + self.assertIs(connection2._connection, connection) def test_uninstrument_connection(self): - connection = mock.Mock() + connection = MockConnectionEmpty() # Set connection.database to avoid a failure because mock can't # be concatenated + # pylint: disable=attribute-defined-outside-init connection.database = "-" connection2 = dbapi.instrument_connection(self.tracer, connection, "-") - self.assertIs(connection2.__wrapped__, connection) + self.assertIs(connection2._connection, connection) connection3 = dbapi.uninstrument_connection(connection2) self.assertIs(connection3, connection) @@ -307,10 +309,12 @@ def mock_connect(*args, **kwargs): server_host = kwargs.get("server_host") server_port = kwargs.get("server_port") user = kwargs.get("user") - return MockConnection(database, server_port, server_host, user) + return MockConnectionWithAttributes( + database, server_port, server_host, user + ) -class MockConnection: +class MockConnectionWithAttributes: def __init__(self, database, server_port, server_host, user): self.database = database self.server_port = server_port @@ -343,3 +347,7 @@ def executemany(self, query, params=None, throw_exception=False): def callproc(self, query, params=None, throw_exception=False): if throw_exception: raise Exception("Test Exception") + + +class MockConnectionEmpty: + pass diff --git a/instrumentation/opentelemetry-instrumentation-mysql/tests/test_mysql_integration.py b/instrumentation/opentelemetry-instrumentation-mysql/tests/test_mysql_integration.py index fc45b72b46..e2a0f2057c 100644 --- a/instrumentation/opentelemetry-instrumentation-mysql/tests/test_mysql_integration.py +++ b/instrumentation/opentelemetry-instrumentation-mysql/tests/test_mysql_integration.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest import mock +from unittest.mock import Mock, patch import mysql.connector @@ -22,15 +22,24 @@ from opentelemetry.test.test_base import TestBase +def mock_connect(*args, **kwargs): + class MockConnection: + def cursor(self): + # pylint: disable=no-self-use + return Mock() + + return MockConnection() + + class TestMysqlIntegration(TestBase): def tearDown(self): super().tearDown() with self.disable_logging(): MySQLInstrumentor().uninstrument() - @mock.patch("mysql.connector.connect") + @patch("mysql.connector.connect", new=mock_connect) # pylint: disable=unused-argument - def test_instrumentor(self, mock_connect): + def test_instrumentor(self): MySQLInstrumentor().instrument() cnx = mysql.connector.connect(database="test") @@ -58,9 +67,8 @@ def test_instrumentor(self, mock_connect): spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) - @mock.patch("mysql.connector.connect") - # pylint: disable=unused-argument - def test_custom_tracer_provider(self, mock_connect): + @patch("mysql.connector.connect", new=mock_connect) + def test_custom_tracer_provider(self): resource = resources.Resource.create({}) result = self.create_tracer_provider(resource=resource) tracer_provider, exporter = result @@ -77,9 +85,9 @@ def test_custom_tracer_provider(self, mock_connect): self.assertIs(span.resource, resource) - @mock.patch("mysql.connector.connect") + @patch("mysql.connector.connect", new=mock_connect) # pylint: disable=unused-argument - def test_instrument_connection(self, mock_connect): + def test_instrument_connection(self): cnx = mysql.connector.connect(database="test") query = "SELECT * FROM test" cursor = cnx.cursor() @@ -95,9 +103,9 @@ def test_instrument_connection(self, mock_connect): spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) - @mock.patch("mysql.connector.connect") + @patch("mysql.connector.connect", new=mock_connect) # pylint: disable=unused-argument - def test_uninstrument_connection(self, mock_connect): + def test_uninstrument_connection(self): MySQLInstrumentor().instrument() cnx = mysql.connector.connect(database="test") query = "SELECT * FROM test" diff --git a/instrumentation/opentelemetry-instrumentation-pymysql/tests/test_pymysql_integration.py b/instrumentation/opentelemetry-instrumentation-pymysql/tests/test_pymysql_integration.py index 587ebc1b53..42dd94f2da 100644 --- a/instrumentation/opentelemetry-instrumentation-pymysql/tests/test_pymysql_integration.py +++ b/instrumentation/opentelemetry-instrumentation-pymysql/tests/test_pymysql_integration.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest import mock +from unittest.mock import Mock, patch import pymysql @@ -22,15 +22,24 @@ from opentelemetry.test.test_base import TestBase +def mock_connect(*args, **kwargs): + class MockConnection: + def cursor(self): + # pylint: disable=no-self-use + return Mock() + + return MockConnection() + + class TestPyMysqlIntegration(TestBase): def tearDown(self): super().tearDown() with self.disable_logging(): PyMySQLInstrumentor().uninstrument() - @mock.patch("pymysql.connect") + @patch("pymysql.connect", new=mock_connect) # pylint: disable=unused-argument - def test_instrumentor(self, mock_connect): + def test_instrumentor(self): PyMySQLInstrumentor().instrument() cnx = pymysql.connect(database="test") @@ -58,9 +67,9 @@ def test_instrumentor(self, mock_connect): spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) - @mock.patch("pymysql.connect") + @patch("pymysql.connect", new=mock_connect) # pylint: disable=unused-argument - def test_custom_tracer_provider(self, mock_connect): + def test_custom_tracer_provider(self): resource = resources.Resource.create({}) result = self.create_tracer_provider(resource=resource) tracer_provider, exporter = result @@ -78,9 +87,9 @@ def test_custom_tracer_provider(self, mock_connect): self.assertIs(span.resource, resource) - @mock.patch("pymysql.connect") + @patch("pymysql.connect", new=mock_connect) # pylint: disable=unused-argument - def test_instrument_connection(self, mock_connect): + def test_instrument_connection(self): cnx = pymysql.connect(database="test") query = "SELECT * FROM test" cursor = cnx.cursor() @@ -96,9 +105,9 @@ def test_instrument_connection(self, mock_connect): spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) - @mock.patch("pymysql.connect") + @patch("pymysql.connect", new=mock_connect) # pylint: disable=unused-argument - def test_uninstrument_connection(self, mock_connect): + def test_uninstrument_connection(self): PyMySQLInstrumentor().instrument() cnx = pymysql.connect(database="test") query = "SELECT * FROM test"