Skip to content

Commit

Permalink
pythongh-129928: Raise more accurate exception for sqlite3 UDF creati…
Browse files Browse the repository at this point in the history
…on misuse

Consistently raise ProgrammingError if the user tries to create an UDF
with an invalid number of parameters.
  • Loading branch information
erlend-aasland committed Feb 10, 2025
1 parent 7e6ee50 commit 8424b44
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 6 deletions.
9 changes: 4 additions & 5 deletions Lib/test/test_sqlite3/test_userfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def tearDown(self):
self.con.close()

def test_func_error_on_create(self):
with self.assertRaises(sqlite.OperationalError):
with self.assertRaisesRegex(sqlite.ProgrammingError, "not -100"):
self.con.create_function("bla", -100, lambda x: 2*x)

def test_func_too_many_args(self):
Expand Down Expand Up @@ -507,9 +507,8 @@ def test_win_sum_int(self):
self.assertEqual(self.cur.fetchall(), self.expected)

def test_win_error_on_create(self):
self.assertRaises(sqlite.ProgrammingError,
self.con.create_window_function,
"shouldfail", -100, WindowSumInt)
with self.assertRaisesRegex(sqlite.ProgrammingError, "not -100"):
self.con.create_window_function("shouldfail", -100, WindowSumInt)

@with_tracebacks(BadWindow)
def test_win_exception_in_method(self):
Expand Down Expand Up @@ -638,7 +637,7 @@ def tearDown(self):
self.con.close()

def test_aggr_error_on_create(self):
with self.assertRaises(sqlite.OperationalError):
with self.assertRaisesRegex(sqlite.ProgrammingError, "not -100"):
self.con.create_function("bla", -100, AggrSum)

@with_tracebacks(AttributeError, msg_regex="AggrNoStep")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Raise :exc:`sqlite3.ProgrammingError` if a user-defined SQL function with
invalid number of parameters is created. Patch by Erlend Aasland.
24 changes: 23 additions & 1 deletion Modules/_sqlite/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,20 @@ destructor_callback(void *ctx)
}
}

static int
check_num_params(pysqlite_Connection *self, const int n, const char *name)
{
int limit = sqlite3_limit(self->db, SQLITE_LIMIT_FUNCTION_ARG, -1);
assert(limit >= 0);
if (n < -1 || n > limit) {
PyErr_Format(self->ProgrammingError,
"'%s' must be between -1 and %d, not %d",
name, limit, n);
return -1;
}
return 0;
}

/*[clinic input]
_sqlite3.Connection.create_function as pysqlite_connection_create_function
Expand Down Expand Up @@ -1167,6 +1181,9 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
return NULL;
}
if (check_num_params(self, narg, "narg") < 0) {
return NULL;
}

if (deterministic) {
flags |= SQLITE_DETERMINISTIC;
Expand Down Expand Up @@ -1307,10 +1324,12 @@ create_window_function_impl(pysqlite_Connection *self, PyTypeObject *cls,
"SQLite 3.25.0 or higher");
return NULL;
}

if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
return NULL;
}
if (check_num_params(self, num_params, "num_params") < 0) {
return NULL;
}

int flags = SQLITE_UTF8;
int rc;
Expand Down Expand Up @@ -1367,6 +1386,9 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
return NULL;
}
if (check_num_params(self, n_arg, "n_arg") < 0) {
return NULL;
}

callback_context *ctx = create_callback_context(cls, aggregate_class);
if (ctx == NULL) {
Expand Down

0 comments on commit 8424b44

Please sign in to comment.