diff --git a/code/ndarray.c b/code/ndarray.c index 45af5f72..b4dcf825 100644 --- a/code/ndarray.c +++ b/code/ndarray.c @@ -1986,8 +1986,15 @@ mp_obj_t ndarray_info(mp_obj_t obj_in) { MP_DEFINE_CONST_FUN_OBJ_1(ndarray_info_obj, ndarray_info); #endif +// (the get_buffer protocol returns 0 for success, 1 for failure) mp_int_t ndarray_get_buffer(mp_obj_t self_in, mp_buffer_info_t *bufinfo, mp_uint_t flags) { ndarray_obj_t *self = MP_OBJ_TO_PTR(self_in); - // buffer_p.get_buffer() returns zero for success, while mp_get_buffer returns true for success - return !mp_get_buffer(self->array, bufinfo, flags); + if (self->ndim != 1 || self->strides[0] > 1) { + // For now, only allow fetching buffer of a 1d-array + return 1; + } + bufinfo->len = self->itemsize * self->len; + bufinfo->buf = self->array; + bufinfo->typecode = self->dtype; + return 0; } diff --git a/tests/common/buffer.py b/tests/common/buffer.py new file mode 100644 index 00000000..5ea952c3 --- /dev/null +++ b/tests/common/buffer.py @@ -0,0 +1,16 @@ +try: + import ulab as np +except: + import numpy as np + +def print_as_buffer(a): + print(len(memoryview(a)), list(memoryview(a))) +print_as_buffer(np.ones(3)) +print_as_buffer(np.zeros(3)) +print_as_buffer(np.ones(1, dtype=np.int8)) +print_as_buffer(np.ones(2, dtype=np.uint8)) +print_as_buffer(np.ones(3, dtype=np.int16)) +print_as_buffer(np.ones(4, dtype=np.uint16)) +print_as_buffer(np.ones(5, dtype=np.float)) +print_as_buffer(np.linspace(0, 1, 9)) + diff --git a/tests/common/buffer.py.exp b/tests/common/buffer.py.exp new file mode 100644 index 00000000..0a91bb80 --- /dev/null +++ b/tests/common/buffer.py.exp @@ -0,0 +1,8 @@ +3 [1.0, 1.0, 1.0] +3 [0.0, 0.0, 0.0] +1 [1] +2 [1, 1] +3 [1, 1, 1] +4 [1, 1, 1, 1] +5 [1.0, 1.0, 1.0, 1.0, 1.0] +9 [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]