From 23ea3202294a5f3758b6d496404a875527386555 Mon Sep 17 00:00:00 2001
From: Wenzel Jakob
Date: Wed, 23 Aug 2023 12:23:33 +0200
Subject: [PATCH] Fix implicit conversion of ``nb::ndarray`` to contiguous layout

---
 src/nb_ndarray.cpp     | 17 +++++++++--------
 tests/test_ndarray.cpp |  2 ++
 tests/test_ndarray.py  | 21 +++++++++++++++++++++
 3 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/src/nb_ndarray.cpp b/src/nb_ndarray.cpp
index 9e36c68a..3d3d6c11 100644
--- a/src/nb_ndarray.cpp
+++ b/src/nb_ndarray.cpp
@@ -415,26 +415,27 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_req *req,
         str module_name_o = borrow<str>(handle(tp).attr("__module__"));
         const char *module_name = module_name_o.c_str();
 
-        char order = 'K';
+        char order = 'K'; // for NumPy. 'K' means 'keep'
         if (req->req_order != '\0')
             order = req->req_order;
 
-        if (req->dtype.lanes != 1)
+        dlpack::dtype dt = req->req_dtype ? req->dtype : t.dtype;
+        if (dt.lanes != 1)
             return nullptr;
 
         const char *prefix = nullptr;
         char dtype[9];
-        if (req->dtype.code == (uint8_t) dlpack::dtype_code::Bool) {
+        if (dt.code == (uint8_t) dlpack::dtype_code::Bool) {
             std::strcpy(dtype, "bool");
         } else {
-            switch (req->dtype.code) {
+            switch (dt.code) {
                 case (uint8_t) dlpack::dtype_code::Int: prefix = "int"; break;
                 case (uint8_t) dlpack::dtype_code::UInt: prefix = "uint"; break;
                 case (uint8_t) dlpack::dtype_code::Float: prefix = "float"; break;
                 default:
                     return nullptr;
             }
-            snprintf(dtype, sizeof(dtype), "%s%u", prefix, req->dtype.bits);
+            snprintf(dtype, sizeof(dtype), "%s%u", prefix, dt.bits);
         }
 
         object converted;
@@ -443,9 +444,9 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_req *req,
             converted = handle(o).attr("astype")(dtype, order);
         } else if (strcmp(module_name, "torch") == 0) {
             converted = handle(o).attr("to")(
-                arg("dtype") = module_::import_("torch").attr(dtype),
-                arg("copy") = true
-            );
+                arg("dtype") = module_::import_("torch").attr(dtype));
+            if (req->req_order == 'C')
+                converted = converted.attr("contiguous")();
         } else if (strncmp(module_name, "tensorflow.", 11) == 0) {
             converted = module_::import_("tensorflow")
                             .attr("cast")(handle(o), dtype);
diff --git a/tests/test_ndarray.cpp b/tests/test_ndarray.cpp
index bb0a52ec..d585d4c6 100644
--- a/tests/test_ndarray.cpp
+++ b/tests/test_ndarray.cpp
@@ -69,6 +69,8 @@ NB_MODULE(test_ndarray_ext, m) {
     m.def("check_order", [](nb::ndarray<nb::f_contig>) -> char { return 'F'; });
     m.def("check_order", [](nb::ndarray<>) -> char { return '?'; });
 
+    m.def("make_contig", [](nb::ndarray<nb::c_contig> a) { return a; });
+
     m.def("check_device", [](nb::ndarray<nb::device::cpu>) -> const char * { return "cpu"; });
     m.def("check_device", [](nb::ndarray<nb::device::cuda>) -> const char * { return "cuda"; });
 
diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py
index 29973a56..aea7a843 100644
--- a/tests/test_ndarray.py
+++ b/tests/test_ndarray.py
@@ -540,3 +540,24 @@ def test28_reference_internal():
 
     msg = 'nanobind::detail::ndarray_wrap(): reference_internal policy cannot be applied (ndarray already has an owner)'
     assert msg in str(excinfo.value)
+
+@needs_numpy
+def test29_force_contig_numpy():
+    a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+    b = t.make_contig(a)
+    assert b is a
+    a = a.T
+    b = t.make_contig(a)
+    assert b is not a
+    assert np.all(b == a)
+
+@needs_torch
+@pytest.mark.filterwarnings("ignore")
+def test30_force_contig_pytorch():
+    a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+    b = t.make_contig(a)
+    assert b is a
+    a = a.T
+    b = t.make_contig(a)
+    assert b is not a
+    assert torch.all(b == a)
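
Note: a rough Python-level sketch of what the patched torch code path in
ndarray_import() now does. Previously the code called tensor.to(dtype=...,
copy=True), which copies but does not change the memory layout; the fix
performs the dtype conversion first and then calls .contiguous() when the
caller requested C ordering. The helper name `convert_like_patched` is
hypothetical and used only for illustration.

    import torch

    def convert_like_patched(tensor: torch.Tensor, dtype, req_order: str) -> torch.Tensor:
        # dtype conversion only; torch's .to() has no 'order' parameter
        converted = tensor.to(dtype=dtype)
        if req_order == 'C':
            # force a C-contiguous copy if the layout does not already match
            converted = converted.contiguous()
        return converted

    a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).T  # non-contiguous view
    b = convert_like_patched(a, torch.int64, 'C')
    assert b.is_contiguous() and torch.equal(b, a)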