Skip to content

Commit

Permalink
Don't implicitly convert complex to non-complex ndarrays (fixes issues
Browse files Browse the repository at this point in the history
…#364, 3365)
  • Loading branch information
wjakob committed Nov 16, 2023
1 parent fa121bd commit ea2569f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/nb_ndarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,12 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_req *req,
}
}

bool refused_conversion = t.dtype.code == (uint8_t) dlpack::dtype_code::Complex &&
req->dtype.code != (uint8_t) dlpack::dtype_code::Complex;

// Support implicit conversion of 'dtype' and order
if (pass_device && pass_shape && (!pass_dtype || !pass_order) && convert &&
capsule.ptr() != o) {
capsule.ptr() != o && !refused_conversion) {
PyTypeObject *tp = Py_TYPE(o);
str module_name_o = borrow<str>(handle(tp).attr("__module__"));
const char *module_name = module_name_o.c_str();
Expand Down
9 changes: 9 additions & 0 deletions tests/test_ndarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,4 +296,13 @@ NB_MODULE(test_ndarray_ext, m) {
else
return Ret(nb::ndarray<nb::numpy, int, nb::shape<>>(i_global, 0, nullptr));
});

// issue #365
m.def("set_item", [](nb::ndarray<double, nb::ndim<1>, nb::c_contig> data, uint32_t) {
data(0) = 123;
});
m.def("set_item",
[](nb::ndarray<std::complex<double>, nb::ndim<1>, nb::c_contig> data, uint32_t) {
data(0) = 123;
});
}
12 changes: 12 additions & 0 deletions tests/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,3 +632,15 @@ def test34_complex_decompose():

assert np.all(x1.real == np.array([1, 3, 5], dtype=np.float32))
assert np.all(x1.imag == np.array([2, 4, 6], dtype=np.float32))

@needs_numpy
@pytest.mark.parametrize("variant", [1, 2])
def test_uint32_complex_do_not_convert(variant):
if variant == 1:
arg = 1
else:
arg = np.uint32(1)
data = np.array([1.0 + 2.0j, 3.0 + 4.0j])
t.set_item(data, arg)
data2 = np.array([123, 3.0 + 4.0j])
assert np.all(data == data2)

0 comments on commit ea2569f

Please sign in to comment.