Skip to content

Commit

Permalink
bpo-44822: Don't truncate strs with embedded NULL chars returned by…
Browse files Browse the repository at this point in the history
… `sqlite3` UDF callbacks (GH-27588)
  • Loading branch information
Erlend Egeberg Aasland authored Aug 5, 2021
1 parent 3e4cb7f commit 8f010dc
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
28 changes: 28 additions & 0 deletions Lib/sqlite3/test/userfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def wrapper(self, *args, **kwargs):

def func_returntext():
return "foo"
def func_returntextwithnull():
return "1\x002"
def func_returnunicode():
return "bar"
def func_returnint():
Expand Down Expand Up @@ -163,11 +165,21 @@ def step(self, val):
def finalize(self):
return self.val

class AggrText:
def __init__(self):
self.txt = ""
def step(self, txt):
self.txt = self.txt + txt
def finalize(self):
return self.txt


class FunctionTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")

self.con.create_function("returntext", 0, func_returntext)
self.con.create_function("returntextwithnull", 0, func_returntextwithnull)
self.con.create_function("returnunicode", 0, func_returnunicode)
self.con.create_function("returnint", 0, func_returnint)
self.con.create_function("returnfloat", 0, func_returnfloat)
Expand Down Expand Up @@ -211,6 +223,12 @@ def test_func_return_text(self):
self.assertEqual(type(val), str)
self.assertEqual(val, "foo")

def test_func_return_text_with_null_char(self):
cur = self.con.cursor()
res = cur.execute("select returntextwithnull()").fetchone()[0]
self.assertEqual(type(res), str)
self.assertEqual(res, "1\x002")

def test_func_return_unicode(self):
cur = self.con.cursor()
cur.execute("select returnunicode()")
Expand Down Expand Up @@ -390,6 +408,7 @@ def setUp(self):
self.con.create_aggregate("checkType", 2, AggrCheckType)
self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
self.con.create_aggregate("mysum", 1, AggrSum)
self.con.create_aggregate("aggtxt", 1, AggrText)

def tearDown(self):
#self.cur.close()
Expand Down Expand Up @@ -486,6 +505,15 @@ def test_aggr_no_match(self):
val = cur.fetchone()[0]
self.assertIsNone(val)

def test_aggr_text(self):
cur = self.con.cursor()
for txt in ["foo", "1\x002"]:
with self.subTest(txt=txt):
cur.execute("select aggtxt(?) from test", (txt,))
val = cur.fetchone()[0]
self.assertEqual(val, txt)


class AuthorizerTests(unittest.TestCase):
@staticmethod
def authorizer_cb(action, arg1, arg2, dbname, source):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
:mod:`sqlite3` user-defined functions and aggregators returning
:class:`strings <str>` with embedded NUL characters are no longer
truncated. Patch by Erlend E. Aasland.
13 changes: 10 additions & 3 deletions Modules/_sqlite/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -519,10 +519,17 @@ _pysqlite_set_result(sqlite3_context* context, PyObject* py_val)
} else if (PyFloat_Check(py_val)) {
sqlite3_result_double(context, PyFloat_AsDouble(py_val));
} else if (PyUnicode_Check(py_val)) {
const char *str = PyUnicode_AsUTF8(py_val);
if (str == NULL)
Py_ssize_t sz;
const char *str = PyUnicode_AsUTF8AndSize(py_val, &sz);
if (str == NULL) {
return -1;
sqlite3_result_text(context, str, -1, SQLITE_TRANSIENT);
}
if (sz > INT_MAX) {
PyErr_SetString(PyExc_OverflowError,
"string is longer than INT_MAX bytes");
return -1;
}
sqlite3_result_text(context, str, (int)sz, SQLITE_TRANSIENT);
} else if (PyObject_CheckBuffer(py_val)) {
Py_buffer view;
if (PyObject_GetBuffer(py_val, &view, PyBUF_SIMPLE) != 0) {
Expand Down

0 comments on commit 8f010dc

Please sign in to comment.