diff --git a/src/tensor.cpp b/src/tensor.cpp index dd53ad5f..0c0b0f1c 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -544,6 +544,7 @@ static void tensor_capsule_destructor(PyObject *o) { error_scope scope; // temporarily save any existing errors managed_tensor *mt = (managed_tensor *) PyCapsule_GetPointer(o, "dltensor"); + if (mt) tensor_dec_ref((tensor_handle *) mt->manager_ctx); else @@ -551,39 +552,48 @@ static void tensor_capsule_destructor(PyObject *o) { } PyObject *tensor_wrap(tensor_handle *th, int framework) noexcept { + if (!th) + return none().release().ptr(); + tensor_inc_ref(th); - object o = steal(PyCapsule_New(th->tensor, "dltensor", tensor_capsule_destructor)), + object o = steal(PyCapsule_New(th->tensor, "dltensor", + tensor_capsule_destructor)), package; - switch ((tensor_framework) framework) { - case tensor_framework::none: - break; + try { + switch ((tensor_framework) framework) { + case tensor_framework::none: + break; - case tensor_framework::numpy: - package = module_::import_("numpy"); - o = handle(internals_get().nb_tensor)(o); - break; + case tensor_framework::numpy: + package = module_::import_("numpy"); + o = handle(internals_get().nb_tensor)(o); + break; - case tensor_framework::pytorch: - package = module_::import_("torch.utils.dlpack"); - break; + case tensor_framework::pytorch: + package = module_::import_("torch.utils.dlpack"); + break; - case tensor_framework::tensorflow: - package = module_::import_("tensorflow.experimental.dlpack"); - break; + case tensor_framework::tensorflow: + package = module_::import_("tensorflow.experimental.dlpack"); + break; - case tensor_framework::jax: - package = module_::import_("jax.dlpack"); - break; + case tensor_framework::jax: + package = module_::import_("jax.dlpack"); + break; - default: - fail("nanobind::detail::tensor_wrap(): unknown framework " - "specified!"); + default: + fail("nanobind::detail::tensor_wrap(): unknown framework " + "specified!"); + } + } catch (const std::exception &e) { + PyErr_Format(PyExc_RuntimeError, + "Could not import tensor framework: %s", e.what()); + return nullptr; } - if (package.is_valid()) { try { o = package.attr("from_dlpack")(o); diff --git a/tests/test_tensor.cpp b/tests/test_tensor.cpp index 323f1da8..9bbf4314 100644 --- a/tests/test_tensor.cpp +++ b/tests/test_tensor.cpp @@ -136,5 +136,5 @@ NB_MODULE(test_tensor_ext, m) { }); return nb::tensor(f, 0, shape, deleter); - }); + }); }