Skip to content

Commit

Permalink
Fix ruff errors
Browse files Browse the repository at this point in the history
Ensure nan propagation is still handled correctly for CuPy sign().
  • Loading branch information
asmeurer committed Oct 24, 2024
1 parent dd44814 commit 2539057
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
9 changes: 6 additions & 3 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,14 +532,17 @@ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:

# numpy 1.26 does not use the standard definition for sign on complex numbers

def sign(x: array, /, xp, **kwargs) -> array:
def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
if isdtype(x.dtype, 'complex floating', xp=xp):
out = (x/xp.abs(x, **kwargs))[...]
# sign(0) = 0 but the above formula would give nan
out[x == 0+0j] = 0+0j
return out[()]
else:
return xp.sign(x, **kwargs)
out = xp.sign(x, **kwargs)
# CuPy sign() does not propagate nans. See
# https://github.com/data-apis/array-api-compat/issues/136
out[xp.isnan(x)] = xp.nan
return out[()]

__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
Expand Down
7 changes: 0 additions & 7 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,6 @@ def asarray(

return cp.array(obj, dtype=dtype, **kwargs)

def sign(x: ndarray, /) -> ndarray:
# CuPy sign() does not propagate nans. See
# https://github.com/data-apis/array-api-compat/issues/136
out = cp.sign(x)
out[cp.isnan(x)] = cp.nan
return out

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp, 'vecdot'):
Expand Down

0 comments on commit 2539057

Please sign in to comment.