diff --git a/tests/clients/test_mysqlclient.py b/tests/clients/test_mysqlclient.py index 0791d44ea..b8766a5d0 100644 --- a/tests/clients/test_mysqlclient.py +++ b/tests/clients/test_mysqlclient.py @@ -14,38 +14,7 @@ logger = logging.getLogger(__name__) -create_table_query = 'CREATE TABLE IF NOT EXISTS users(id serial primary key, \ - name varchar(40) NOT NULL, email varchar(40) NOT NULL)' -create_proc_query = """ -CREATE PROCEDURE test_proc(IN t VARCHAR(255)) -BEGIN - SELECT name FROM users WHERE name = t; -END -""" - -db = MySQLdb.connect(host=testenv['mysql_host'], port=testenv['mysql_port'], - user=testenv['mysql_user'], passwd=testenv['mysql_pw'], - db=testenv['mysql_db']) - -cursor = db.cursor() -cursor.execute(create_table_query) - -while cursor.nextset() is not None: - pass - -cursor.execute('DROP PROCEDURE IF EXISTS test_proc') - -while cursor.nextset() is not None: - pass - -cursor.execute(create_proc_query) - -while cursor.nextset() is not None: - pass - -cursor.close() -db.close() class TestMySQLPython(unittest.TestCase): @@ -53,6 +22,22 @@ def setUp(self): self.db = MySQLdb.connect(host=testenv['mysql_host'], port=testenv['mysql_port'], user=testenv['mysql_user'], passwd=testenv['mysql_pw'], db=testenv['mysql_db']) + database_setup_query = """ + DROP TABLE IF EXISTS users; + CREATE TABLE users(id serial primary key, \ + name varchar(40) NOT NULL, \ + email varchar(40) NOT NULL); + INSERT INTO users(name, email) VALUES('kermit', 'kermit@muppets.com'); + DROP PROCEDURE IF EXISTS test_proc; + CREATE PROCEDURE test_proc(IN t VARCHAR(255)) + BEGIN + SELECT name FROM users WHERE name = t; + END + """ + setup_cursor = self.db.cursor() + setup_cursor.execute(database_setup_query) + setup_cursor.close() + self.cursor = self.db.cursor() self.recorder = tracer.recorder self.recorder.clear_spans() @@ -60,10 +45,14 @@ def setUp(self): def tearDown(self): """ Do nothing for now """ - return None + if self.cursor and self.cursor.connection.open: + self.cursor.close() + if self.db and self.db.open: + self.db.close() def test_vanilla_query(self): - self.cursor.execute("""SELECT * from users""") + affected_rows = self.cursor.execute("""SELECT * from users""") + self.assertEqual(1, affected_rows) result = self.cursor.fetchone() self.assertEqual(3, len(result)) @@ -72,10 +61,11 @@ def test_vanilla_query(self): def test_basic_query(self): with tracer.start_active_span('test'): - result = self.cursor.execute("""SELECT * from users""") - self.cursor.fetchone() + affected_rows = self.cursor.execute("""SELECT * from users""") + result = self.cursor.fetchone() - self.assertTrue(result >= 0) + self.assertEqual(1, affected_rows) + self.assertEqual(3, len(result)) spans = self.recorder.queued_spans() self.assertEqual(2, len(spans)) @@ -97,11 +87,11 @@ def test_basic_query(self): def test_basic_insert(self): with tracer.start_active_span('test'): - result = self.cursor.execute( + affected_rows = self.cursor.execute( """INSERT INTO users(name, email) VALUES(%s, %s)""", ('beaker', 'beaker@muppets.com')) - self.assertEqual(1, result) + self.assertEqual(1, affected_rows) spans = self.recorder.queued_spans() self.assertEqual(2, len(spans)) @@ -123,11 +113,11 @@ def test_basic_insert(self): def test_executemany(self): with tracer.start_active_span('test'): - result = self.cursor.executemany("INSERT INTO users(name, email) VALUES(%s, %s)", + affected_rows = self.cursor.executemany("INSERT INTO users(name, email) VALUES(%s, %s)", [('beaker', 'beaker@muppets.com'), ('beaker', 'beaker@muppets.com')]) self.db.commit() - self.assertEqual(2, result) + self.assertEqual(2, affected_rows) spans = self.recorder.queued_spans() self.assertEqual(2, len(spans)) @@ -149,9 +139,9 @@ def test_executemany(self): def test_call_proc(self): with tracer.start_active_span('test'): - result = self.cursor.callproc('test_proc', ('beaker',)) + callproc_result = self.cursor.callproc('test_proc', ('beaker',)) - self.assertTrue(result) + self.assertTrue(callproc_result) spans = self.recorder.queued_spans() self.assertEqual(2, len(spans)) @@ -172,15 +162,14 @@ def test_call_proc(self): self.assertEqual(db_span.data["mysql"]["port"], testenv['mysql_port']) def test_error_capture(self): - result = None + affected_rows = None try: with tracer.start_active_span('test'): - result = self.cursor.execute("""SELECT * from blah""") - self.cursor.fetchone() + affected_rows = self.cursor.execute("""SELECT * from blah""") except Exception: pass - self.assertIsNone(result) + self.assertIsNone(affected_rows) spans = self.recorder.queued_spans() self.assertEqual(2, len(spans))