Skip to content

Commit

Permalink
nb_type_vectorcall(): avoid the use of alloca()
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Sep 3, 2024
1 parent ec276ad commit e24d7f3
Showing 1 changed file with 31 additions and 10 deletions.
41 changes: 31 additions & 10 deletions src/nb_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -863,8 +863,8 @@ static PyMethodDef class_getitem_method[] = {
};
#endif

// Implements the vector call protocol directly on a object
// to construct instances more efficiently.
// Implements the vector call protocol directly on type objects to construct
// instances more efficiently.
static PyObject *nb_type_vectorcall(PyObject *self, PyObject *const *args_in,
size_t nargsf,
PyObject *kwargs_in) noexcept {
Expand All @@ -883,23 +883,41 @@ static PyObject *nb_type_vectorcall(PyObject *self, PyObject *const *args_in,
self = inst_new_int(tp, nullptr, nullptr);
if (!self)
return nullptr;
} else if (nargs == 0 && !kwargs_in) {
if (nb_func_data(func)->nargs != 0) // fail
return func->vectorcall((PyObject *) func, nullptr, 0, nullptr);
} else if (nargs == 0 && !kwargs_in && nb_func_data(func)->nargs != 0) {
// When the bindings define a custom __new__ operator, nanobind always
// provides a no-argument dummy __new__ constructor to handle unpickling
// via __setstate__. This is an implementation detail that should not be
// exposed. Therefore, only allow argument-less calls if there is an
// actual __new__ overload with a compatible signature.

return func->vectorcall((PyObject *) func, nullptr, 0, nullptr);
}

PyObject **args = nullptr, *temp = nullptr;
const size_t buf_size = 5;
PyObject **args, *buf[buf_size], *temp = nullptr;
bool alloc = false;

if (NB_LIKELY(nargsf & NB_VECTORCALL_ARGUMENTS_OFFSET)) {
args = (PyObject **) (args_in - 1);
temp = args[0];
} else {
size_t size = nargs;
size_t size = nargs + 1;
if (kwargs_in)
size += NB_TUPLE_GET_SIZE(kwargs_in);
args = (PyObject **) alloca(((size_t) size + 1) * sizeof(PyObject *));
if (size)
memcpy(args + 1, args_in, sizeof(PyObject *) * size);

if (size < buf_size) {
args = buf;
} else {
args = (PyObject **) PyMem_Malloc(size * sizeof(PyObject *));
if (!args) {
if (is_init)
Py_DECREF(self);
return PyErr_NoMemory();
}
alloc = true;
}

memcpy(args + 1, args_in, sizeof(PyObject *) * (size - 1));
}

args[0] = self;
Expand All @@ -909,6 +927,9 @@ static PyObject *nb_type_vectorcall(PyObject *self, PyObject *const *args_in,

args[0] = temp;

if (NB_UNLIKELY(alloc))
PyMem_Free(args);

if (NB_LIKELY(is_init)) {
if (!rv) {
Py_DECREF(self);
Expand Down

0 comments on commit e24d7f3

Please sign in to comment.