Skip to content

Commit

Permalink
Structure the copy flag in cupy.asarray better
Browse files Browse the repository at this point in the history
This way it is more future-proof for when cupy changes the meaning of
copy=False.
  • Loading branch information
asmeurer committed Mar 22, 2024
1 parent a1eea09 commit d84983a
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
tensordot = get_xp(cp)(_aliases.tensordot)

_copy_default = object()

# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
obj: Union[
Expand All @@ -75,7 +77,7 @@ def asarray(
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
copy: Optional[bool] = None,
copy: Optional[bool] = _copy_default,
**kwargs,
) -> ndarray:
"""
Expand All @@ -90,12 +92,23 @@ def asarray(
with cp.cuda.Device(device):
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
# in asarray in numpy/_aliases.py.
if copy is None:
copy = False
elif copy is False:
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
if copy is not _copy_default:
# A future version of CuPy will change the meaning of copy=False
# to mean no-copy. We don't know for certain what version it will
# be yet, so to avoid breaking that version, we use a different
# default value for copy so asarray(obj) with no copy kwarg will
# always do the copy-if-needed behavior.

# This will still need to be updated to remove the
# NotImplementedError for copy=False, but at least this won't
# break the default or existing behavior.
if copy is None:
copy = False
elif copy is False:
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
kwargs['copy'] = copy

return cp.array(obj, copy=copy, dtype=dtype, **kwargs)
return cp.array(obj, dtype=dtype, **kwargs)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
Expand Down

0 comments on commit d84983a

Please sign in to comment.