Skip to content

Commit

Permalink
test: Fix mysqlclient to make database content predicatable
Browse files Browse the repository at this point in the history
Signed-off-by: Ferenc Géczi <ferenc.geczi@ibm.com>
  • Loading branch information
Ferenc- committed Sep 4, 2023
1 parent b1cea7e commit 5f75f21
Showing 1 changed file with 35 additions and 46 deletions.
81 changes: 35 additions & 46 deletions tests/clients/test_mysqlclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,56 +14,45 @@

logger = logging.getLogger(__name__)

create_table_query = 'CREATE TABLE IF NOT EXISTS users(id serial primary key, \
name varchar(40) NOT NULL, email varchar(40) NOT NULL)'

create_proc_query = """
CREATE PROCEDURE test_proc(IN t VARCHAR(255))
BEGIN
SELECT name FROM users WHERE name = t;
END
"""

db = MySQLdb.connect(host=testenv['mysql_host'], port=testenv['mysql_port'],
user=testenv['mysql_user'], passwd=testenv['mysql_pw'],
db=testenv['mysql_db'])

cursor = db.cursor()
cursor.execute(create_table_query)

while cursor.nextset() is not None:
pass

cursor.execute('DROP PROCEDURE IF EXISTS test_proc')

while cursor.nextset() is not None:
pass

cursor.execute(create_proc_query)

while cursor.nextset() is not None:
pass

cursor.close()
db.close()


class TestMySQLPython(unittest.TestCase):
def setUp(self):
self.db = MySQLdb.connect(host=testenv['mysql_host'], port=testenv['mysql_port'],
user=testenv['mysql_user'], passwd=testenv['mysql_pw'],
db=testenv['mysql_db'])
database_setup_query = """
DROP TABLE IF EXISTS users;
CREATE TABLE users(id serial primary key, \
name varchar(40) NOT NULL, \
email varchar(40) NOT NULL);
INSERT INTO users(name, email) VALUES('kermit', 'kermit@muppets.com');
DROP PROCEDURE IF EXISTS test_proc;
CREATE PROCEDURE test_proc(IN t VARCHAR(255))
BEGIN
SELECT name FROM users WHERE name = t;
END
"""
setup_cursor = self.db.cursor()
setup_cursor.execute(database_setup_query)
setup_cursor.close()

self.cursor = self.db.cursor()
self.recorder = tracer.recorder
self.recorder.clear_spans()
tracer.cur_ctx = None

def tearDown(self):
""" Do nothing for now """
return None
if self.cursor and self.cursor.connection.open:
self.cursor.close()
if self.db and self.db.open:
self.db.close()

def test_vanilla_query(self):
self.cursor.execute("""SELECT * from users""")
affected_rows = self.cursor.execute("""SELECT * from users""")
self.assertEqual(1, affected_rows)
result = self.cursor.fetchone()
self.assertEqual(3, len(result))

Expand All @@ -72,10 +61,11 @@ def test_vanilla_query(self):

def test_basic_query(self):
with tracer.start_active_span('test'):
result = self.cursor.execute("""SELECT * from users""")
self.cursor.fetchone()
affected_rows = self.cursor.execute("""SELECT * from users""")
result = self.cursor.fetchone()

self.assertTrue(result >= 0)
self.assertEqual(1, affected_rows)
self.assertEqual(3, len(result))

spans = self.recorder.queued_spans()
self.assertEqual(2, len(spans))
Expand All @@ -97,11 +87,11 @@ def test_basic_query(self):

def test_basic_insert(self):
with tracer.start_active_span('test'):
result = self.cursor.execute(
affected_rows = self.cursor.execute(
"""INSERT INTO users(name, email) VALUES(%s, %s)""",
('beaker', 'beaker@muppets.com'))

self.assertEqual(1, result)
self.assertEqual(1, affected_rows)

spans = self.recorder.queued_spans()
self.assertEqual(2, len(spans))
Expand All @@ -123,11 +113,11 @@ def test_basic_insert(self):

def test_executemany(self):
with tracer.start_active_span('test'):
result = self.cursor.executemany("INSERT INTO users(name, email) VALUES(%s, %s)",
affected_rows = self.cursor.executemany("INSERT INTO users(name, email) VALUES(%s, %s)",
[('beaker', 'beaker@muppets.com'), ('beaker', 'beaker@muppets.com')])
self.db.commit()

self.assertEqual(2, result)
self.assertEqual(2, affected_rows)

spans = self.recorder.queued_spans()
self.assertEqual(2, len(spans))
Expand All @@ -149,9 +139,9 @@ def test_executemany(self):

def test_call_proc(self):
with tracer.start_active_span('test'):
result = self.cursor.callproc('test_proc', ('beaker',))
callproc_result = self.cursor.callproc('test_proc', ('beaker',))

self.assertTrue(result)
self.assertTrue(callproc_result)

spans = self.recorder.queued_spans()
self.assertEqual(2, len(spans))
Expand All @@ -172,15 +162,14 @@ def test_call_proc(self):
self.assertEqual(db_span.data["mysql"]["port"], testenv['mysql_port'])

def test_error_capture(self):
result = None
affected_rows = None
try:
with tracer.start_active_span('test'):
result = self.cursor.execute("""SELECT * from blah""")
self.cursor.fetchone()
affected_rows = self.cursor.execute("""SELECT * from blah""")
except Exception:
pass

self.assertIsNone(result)
self.assertIsNone(affected_rows)

spans = self.recorder.queued_spans()
self.assertEqual(2, len(spans))
Expand Down

0 comments on commit 5f75f21

Please sign in to comment.