Skip to content

Commit

Permalink
FIX trigger indptr/indices copy when data are copied in astype (scipy…
Browse files Browse the repository at this point in the history
…#18192)

* FIX trigger indptr/indices copy when data are copied in astype

* TST add unit tests

* iter

* Add comment regarding object memview support
  • Loading branch information
glemaitre authored and tylerjereddy committed Apr 16, 2023
1 parent fcf8009 commit 80d2d9e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
12 changes: 10 additions & 2 deletions scipy/sparse/_csparsetools.pyx.in
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,15 @@ def _lil_get_lengths_{{NAME}}(object[:] input,

{{define_dispatch_map('_LIL_GET_LENGTHS_DISPATCH', '_lil_get_lengths', IDX_TYPES)}}

def lil_flatten_to_array(object[:] input,
# We define the fuse type below because Cython does not currently allow to
# declare object memory views (cf. https://github.com/cython/cython/issues/2485)
# We can track the support of object memory views in
# https://github.com/cython/cython/pull/4712
ctypedef fused obj_fused:
object
double

def lil_flatten_to_array(const obj_fused[:] input,
cnp.ndarray output):
return _LIL_FLATTEN_TO_ARRAY_DISPATCH[output.dtype](input, output)

Expand Down Expand Up @@ -311,7 +319,7 @@ def _lil_fancy_set_{{PYIDX}}_{{PYVALUE}}(cnp.npy_intp M, cnp.npy_intp N,


def lil_get_row_ranges(cnp.npy_intp M, cnp.npy_intp N,
object[:] rows, object[:] datas,
const obj_fused[:] rows, const obj_fused[:] datas,
object[:] new_rows, object[:] new_datas,
object irows,
cnp.npy_intp j_start,
Expand Down
8 changes: 5 additions & 3 deletions scipy/sparse/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ def __itruediv__(self, other): # self /= other
def astype(self, dtype, casting='unsafe', copy=True):
dtype = np.dtype(dtype)
if self.dtype != dtype:
return self._with_data(
self._deduped_data().astype(dtype, casting=casting, copy=copy),
copy=copy)
matrix = self._with_data(
self.data.astype(dtype, casting=casting, copy=True),
copy=True
)
return matrix._with_data(matrix._deduped_data(), copy=False)
elif copy:
return self.copy()
else:
Expand Down
18 changes: 18 additions & 0 deletions scipy/sparse/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,6 +1294,24 @@ def check_equal_but_not_same_array_attribute(attribute):
for attribute in ('offsets', 'data'):
check_equal_but_not_same_array_attribute(attribute)

@sup_complex
def test_astype_immutable(self):
D = array([[2.0 + 3j, 0, 0],
[0, 4.0 + 5j, 0],
[0, 0, 0]])
S = self.spmatrix(D)
if hasattr(S, 'data'):
S.data.flags.writeable = False
if hasattr(S, 'indptr'):
S.indptr.flags.writeable = False
if hasattr(S, 'indices'):
S.indices.flags.writeable = False
for x in supported_dtypes:
D_casted = D.astype(x)
S_casted = S.astype(x)
assert_equal(S_casted.dtype, D_casted.dtype)


def test_asfptype(self):
A = self.spmatrix(arange(6,dtype='int32').reshape(2,3))

Expand Down

0 comments on commit 80d2d9e

Please sign in to comment.