Skip to content

Commit

Permalink
Merge pull request #329 from jepler/array-memoryview-legacy
Browse files Browse the repository at this point in the history
ndarray: Fix memoryview(ulab.array(...))
  • Loading branch information
v923z authored Feb 19, 2021
2 parents 743d864 + 7aeb73a commit 8d15661
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 2 deletions.
11 changes: 9 additions & 2 deletions code/ndarray.c
Original file line number Diff line number Diff line change
Expand Up @@ -1986,8 +1986,15 @@ mp_obj_t ndarray_info(mp_obj_t obj_in) {
MP_DEFINE_CONST_FUN_OBJ_1(ndarray_info_obj, ndarray_info);
#endif

// (the get_buffer protocol returns 0 for success, 1 for failure)
mp_int_t ndarray_get_buffer(mp_obj_t self_in, mp_buffer_info_t *bufinfo, mp_uint_t flags) {
ndarray_obj_t *self = MP_OBJ_TO_PTR(self_in);
// buffer_p.get_buffer() returns zero for success, while mp_get_buffer returns true for success
return !mp_get_buffer(self->array, bufinfo, flags);
if (self->ndim != 1 || self->strides[0] > 1) {
// For now, only allow fetching buffer of a 1d-array
return 1;
}
bufinfo->len = self->itemsize * self->len;
bufinfo->buf = self->array;
bufinfo->typecode = self->dtype;
return 0;
}
16 changes: 16 additions & 0 deletions tests/common/buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
try:
import ulab as np
except:
import numpy as np

def print_as_buffer(a):
print(len(memoryview(a)), list(memoryview(a)))
print_as_buffer(np.ones(3))
print_as_buffer(np.zeros(3))
print_as_buffer(np.ones(1, dtype=np.int8))
print_as_buffer(np.ones(2, dtype=np.uint8))
print_as_buffer(np.ones(3, dtype=np.int16))
print_as_buffer(np.ones(4, dtype=np.uint16))
print_as_buffer(np.ones(5, dtype=np.float))
print_as_buffer(np.linspace(0, 1, 9))

8 changes: 8 additions & 0 deletions tests/common/buffer.py.exp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
3 [1.0, 1.0, 1.0]
3 [0.0, 0.0, 0.0]
1 [1]
2 [1, 1]
3 [1, 1, 1]
4 [1, 1, 1, 1]
5 [1.0, 1.0, 1.0, 1.0, 1.0]
9 [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

0 comments on commit 8d15661

Please sign in to comment.