Fix implicit conversion of nb::ndarray to contiguous layout
nanobind ndarrays provide the ``nb::c_contig`` and ``nb::f_contig``
annotations to specify that input arrays must be represented by
contiguous memory blocks in C or Fortran-style ordering. When this is
not the case, nanobind will by default attempt an implicit
conversion.

This conversion previously failed in some cases: when no underlying
scalar type was specified, and when converting from PyTorch. Those
issues are addressed by this commit.

Fixes issue #278.
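
For illustration, here is a minimal sketch of bindings that rely on these
annotations (the module name, function names, and lambdas are hypothetical
and not part of this commit); the second function omits the underlying
scalar type, which is one of the cases whose implicit conversion this
commit repairs:

    #include <nanobind/nanobind.h>
    #include <nanobind/ndarray.h>

    namespace nb = nanobind;

    NB_MODULE(example_ext, m) {
        // Requires a C-contiguous float array. For a non-contiguous input
        // (e.g. a transposed NumPy array or PyTorch tensor), nanobind
        // attempts an implicit conversion to a contiguous copy.
        m.def("sum", [](nb::ndarray<float, nb::c_contig> a) {
            double total = 0;
            for (size_t i = 0; i < a.size(); ++i)
                total += a.data()[i]; // contiguity makes linear indexing valid
            return total;
        });

        // Same constraint, but without an underlying scalar type; the
        // implicit conversion previously failed for such signatures.
        m.def("passthrough", [](nb::ndarray<nb::c_contig> a) { return a; });
    }

With this change, calling such functions with a transposed NumPy array or
PyTorch tensor yields a contiguous copy instead of failing, as exercised by
the new make_contig tests below.
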
wjakob committed Aug 24, 2023
1 parent e4b9a9f commit ed929b7
Showing 3 changed files with 32 additions and 8 deletions.
17 changes: 9 additions & 8 deletions src/nb_ndarray.cpp
@@ -415,26 +415,27 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_req *req,
     str module_name_o = borrow<str>(handle(tp).attr("__module__"));
     const char *module_name = module_name_o.c_str();

-    char order = 'K';
+    char order = 'K'; // for NumPy. 'K' means 'keep'
     if (req->req_order != '\0')
         order = req->req_order;

-    if (req->dtype.lanes != 1)
+    dlpack::dtype dt = req->req_dtype ? req->dtype : t.dtype;
+    if (dt.lanes != 1)
         return nullptr;

     const char *prefix = nullptr;
     char dtype[9];
-    if (req->dtype.code == (uint8_t) dlpack::dtype_code::Bool) {
+    if (dt.code == (uint8_t) dlpack::dtype_code::Bool) {
         std::strcpy(dtype, "bool");
     } else {
-        switch (req->dtype.code) {
+        switch (dt.code) {
             case (uint8_t) dlpack::dtype_code::Int: prefix = "int"; break;
             case (uint8_t) dlpack::dtype_code::UInt: prefix = "uint"; break;
             case (uint8_t) dlpack::dtype_code::Float: prefix = "float"; break;
             default:
                 return nullptr;
         }
-        snprintf(dtype, sizeof(dtype), "%s%u", prefix, req->dtype.bits);
+        snprintf(dtype, sizeof(dtype), "%s%u", prefix, dt.bits);
     }

     object converted;
@@ -443,9 +444,9 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_req *req,
         converted = handle(o).attr("astype")(dtype, order);
     } else if (strcmp(module_name, "torch") == 0) {
         converted = handle(o).attr("to")(
-            arg("dtype") = module_::import_("torch").attr(dtype),
-            arg("copy") = true
-        );
+            arg("dtype") = module_::import_("torch").attr(dtype));
+        if (req->req_order == 'C')
+            converted = converted.attr("contiguous")();
     } else if (strncmp(module_name, "tensorflow.", 11) == 0) {
         converted = module_::import_("tensorflow")
             .attr("cast")(handle(o), dtype);
2 changes: 2 additions & 0 deletions tests/test_ndarray.cpp
@@ -69,6 +69,8 @@ NB_MODULE(test_ndarray_ext, m) {
     m.def("check_order", [](nb::ndarray<nb::f_contig>) -> char { return 'F'; });
     m.def("check_order", [](nb::ndarray<>) -> char { return '?'; });

+    m.def("make_contig", [](nb::ndarray<nb::c_contig> a) { return a; });
+
     m.def("check_device", [](nb::ndarray<nb::device::cpu>) -> const char * { return "cpu"; });
     m.def("check_device", [](nb::ndarray<nb::device::cuda>) -> const char * { return "cuda"; });

21 changes: 21 additions & 0 deletions tests/test_ndarray.py
@@ -540,3 +540,24 @@ def test28_reference_internal():

     msg = 'nanobind::detail::ndarray_wrap(): reference_internal policy cannot be applied (ndarray already has an owner)'
     assert msg in str(excinfo.value)
+
+@needs_numpy
+def test29_force_contig_numpy():
+    a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+    b = t.make_contig(a)
+    assert b is a
+    a = a.T
+    b = t.make_contig(a)
+    assert b is not a
+    assert np.all(b == a)
+
+@needs_torch
+@pytest.mark.filterwarnings
+def test30_force_contig_pytorch():
+    a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+    b = t.make_contig(a)
+    assert b is a
+    a = a.T
+    b = t.make_contig(a)
+    assert b is not a
+    assert torch.all(b == a)
