diff --git a/src/nb_ndarray.cpp b/src/nb_ndarray.cpp index b2017fdf..6e746997 100644 --- a/src/nb_ndarray.cpp +++ b/src/nb_ndarray.cpp @@ -424,9 +424,12 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_req *req, } } + bool refused_conversion = t.dtype.code == (uint8_t) dlpack::dtype_code::Complex && + req->dtype.code != (uint8_t) dlpack::dtype_code::Complex; + // Support implicit conversion of 'dtype' and order if (pass_device && pass_shape && (!pass_dtype || !pass_order) && convert && - capsule.ptr() != o) { + capsule.ptr() != o && !refused_conversion) { PyTypeObject *tp = Py_TYPE(o); str module_name_o = borrow(handle(tp).attr("__module__")); const char *module_name = module_name_o.c_str(); diff --git a/tests/test_ndarray.cpp b/tests/test_ndarray.cpp index e1ef5ed6..0820cf30 100644 --- a/tests/test_ndarray.cpp +++ b/tests/test_ndarray.cpp @@ -296,4 +296,13 @@ NB_MODULE(test_ndarray_ext, m) { else return Ret(nb::ndarray>(i_global, 0, nullptr)); }); + + // issue #365 + m.def("set_item", [](nb::ndarray, nb::c_contig> data, uint32_t) { + data(0) = 123; + }); + m.def("set_item", + [](nb::ndarray, nb::ndim<1>, nb::c_contig> data, uint32_t) { + data(0) = 123; + }); } diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index 195bf405..1bd1bab0 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -632,3 +632,15 @@ def test34_complex_decompose(): assert np.all(x1.real == np.array([1, 3, 5], dtype=np.float32)) assert np.all(x1.imag == np.array([2, 4, 6], dtype=np.float32)) + +@needs_numpy +@pytest.mark.parametrize("variant", [1, 2]) +def test_uint32_complex_do_not_convert(variant): + if variant == 1: + arg = 1 + else: + arg = np.uint32(1) + data = np.array([1.0 + 2.0j, 3.0 + 4.0j]) + t.set_item(data, arg) + data2 = np.array([123, 3.0 + 4.0j]) + assert np.all(data == data2)