Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: normalize NDArray to tensors, add a special-case for out= NDArrays #108

Merged
merged 1 commit into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 35 additions & 31 deletions torch_np/_funcs_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
AxisLike,
DTypeLike,
NDArray,
OutArray,
SubokLike,
normalize_array_like,
)
Expand All @@ -41,8 +42,8 @@ def copy(a: ArrayLike, order="K", subok: SubokLike = False):
def copyto(dst: NDArray, src: ArrayLike, casting="same_kind", where=NoValue):
if where is not NoValue:
raise NotImplementedError
(src,) = _util.typecast_tensors((src,), dst.tensor.dtype, casting=casting)
dst.tensor.copy_(src)
(src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
dst.copy_(src)


def atleast_1d(*arys: ArrayLike):
Expand Down Expand Up @@ -114,7 +115,7 @@ def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
def concatenate(
ar_tuple: Sequence[ArrayLike],
axis=0,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
dtype: DTypeLike = None,
casting="same_kind",
):
Expand Down Expand Up @@ -160,7 +161,7 @@ def column_stack(
def stack(
arrays: Sequence[ArrayLike],
axis=0,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
*,
dtype: DTypeLike = None,
casting="same_kind",
Expand Down Expand Up @@ -754,7 +755,7 @@ def nanmean(
a: ArrayLike,
axis=None,
dtype: DTypeLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
keepdims=NoValue,
*,
where=NoValue,
Expand Down Expand Up @@ -892,7 +893,7 @@ def take(
a: ArrayLike,
indices: ArrayLike,
axis=None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
mode="raise",
):
if mode != "raise":
Expand Down Expand Up @@ -975,7 +976,7 @@ def clip(
a: ArrayLike,
min: Optional[ArrayLike] = None,
max: Optional[ArrayLike] = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
):
# np.clip requires both a_min and a_max not None, while ndarray.clip allows
# one of them to be None. Follow the more lax version.
Expand Down Expand Up @@ -1070,7 +1071,7 @@ def trace(
axis1=0,
axis2=1,
dtype: DTypeLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
):
result = torch.diagonal(a, offset, dim1=axis1, dim2=axis2).sum(-1, dtype=dtype)
return result
Expand Down Expand Up @@ -1180,7 +1181,7 @@ def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
return torch.tensordot(a, b, dims=axes)


def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
def dot(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
a = _util.cast_if_needed(a, dtype)
b = _util.cast_if_needed(b, dtype)
Expand Down Expand Up @@ -1215,7 +1216,7 @@ def inner(a: ArrayLike, b: ArrayLike, /):
return result


def outer(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
return torch.outer(a, b)


Expand Down Expand Up @@ -1382,7 +1383,7 @@ def imag(a: ArrayLike):
return result


def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
if a.is_floating_point():
result = torch.round(a, decimals=decimals)
elif a.is_complex():
Expand All @@ -1408,7 +1409,7 @@ def sum(
a: ArrayLike,
axis: AxisLike = None,
dtype: DTypeLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
keepdims=NoValue,
initial=NoValue,
where=NoValue,
Expand All @@ -1423,7 +1424,7 @@ def prod(
a: ArrayLike,
axis: AxisLike = None,
dtype: DTypeLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
keepdims=NoValue,
initial=NoValue,
where=NoValue,
Expand All @@ -1441,7 +1442,7 @@ def mean(
a: ArrayLike,
axis: AxisLike = None,
dtype: DTypeLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
keepdims=NoValue,
*,
where=NoValue,
Expand All @@ -1454,7 +1455,7 @@ def var(
a: ArrayLike,
axis: AxisLike = None,
dtype: DTypeLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
ddof=0,
keepdims=NoValue,
*,
Expand All @@ -1470,7 +1471,7 @@ def std(
a: ArrayLike,
axis: AxisLike = None,
dtype: DTypeLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
ddof=0,
keepdims=NoValue,
*,
Expand All @@ -1485,7 +1486,7 @@ def std(
def argmin(
a: ArrayLike,
axis: AxisLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
*,
keepdims=NoValue,
):
Expand All @@ -1496,7 +1497,7 @@ def argmin(
def argmax(
a: ArrayLike,
axis: AxisLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
*,
keepdims=NoValue,
):
Expand All @@ -1507,7 +1508,7 @@ def argmax(
def amax(
a: ArrayLike,
axis: AxisLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
keepdims=NoValue,
initial=NoValue,
where=NoValue,
Expand All @@ -1522,7 +1523,7 @@ def amax(
def amin(
a: ArrayLike,
axis: AxisLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
keepdims=NoValue,
initial=NoValue,
where=NoValue,
Expand All @@ -1535,7 +1536,10 @@ def amin(


def ptp(
a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue
a: ArrayLike,
axis: AxisLike = None,
out: Optional[OutArray] = None,
keepdims=NoValue,
):
result = _impl.ptp(a, axis=axis, keepdims=keepdims)
return result
Expand All @@ -1544,7 +1548,7 @@ def ptp(
def all(
a: ArrayLike,
axis: AxisLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
keepdims=NoValue,
*,
where=NoValue,
Expand All @@ -1556,7 +1560,7 @@ def all(
def any(
a: ArrayLike,
axis: AxisLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
keepdims=NoValue,
*,
where=NoValue,
Expand All @@ -1574,7 +1578,7 @@ def cumsum(
a: ArrayLike,
axis: AxisLike = None,
dtype: DTypeLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
):
result = _impl.cumsum(a, axis=axis, dtype=dtype)
return result
Expand All @@ -1584,7 +1588,7 @@ def cumprod(
a: ArrayLike,
axis: AxisLike = None,
dtype: DTypeLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
):
result = _impl.cumprod(a, axis=axis, dtype=dtype)
return result
Expand All @@ -1597,7 +1601,7 @@ def quantile(
a: ArrayLike,
q: ArrayLike,
axis: AxisLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
overwrite_input=False,
method="linear",
keepdims=False,
Expand All @@ -1620,7 +1624,7 @@ def percentile(
a: ArrayLike,
q: ArrayLike,
axis: AxisLike = None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
overwrite_input=False,
method="linear",
keepdims=False,
Expand All @@ -1642,7 +1646,7 @@ def percentile(
def median(
a: ArrayLike,
axis=None,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
overwrite_input=False,
keepdims=False,
):
Expand Down Expand Up @@ -1726,7 +1730,7 @@ def imag(a: ArrayLike):
return result


def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
if a.is_floating_point():
result = torch.round(a, decimals=decimals)
elif a.is_complex():
Expand Down Expand Up @@ -1786,11 +1790,11 @@ def isrealobj(x: ArrayLike):
return not torch.is_complex(x)


def isneginf(x: ArrayLike, out: Optional[NDArray] = None):
def isneginf(x: ArrayLike, out: Optional[OutArray] = None):
return torch.isneginf(x, out=out)


def isposinf(x: ArrayLike, out: Optional[NDArray] = None):
def isposinf(x: ArrayLike, out: Optional[OutArray] = None):
return torch.isposinf(x, out=out)


Expand Down
19 changes: 19 additions & 0 deletions torch_np/_normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
SubokLike = typing.TypeVar("SubokLike")
AxisLike = typing.TypeVar("AxisLike")
NDArray = typing.TypeVar("NDarray")
OutArray = typing.TypeVar("OutArray")


import inspect
Expand Down Expand Up @@ -60,6 +61,19 @@ def normalize_axis_like(arg, name=None):


def normalize_ndarray(arg, name=None):
    """Normalize an ``ndarray``-typed argument to its backing tensor.

    Parameters
    ----------
    arg : torch_np.ndarray or None
        The argument to normalize. ``None`` is passed through unchanged
        (e.g. an omitted ``out=`` argument).
    name : str, optional
        The parameter name this value was bound to; used to produce a
        precise error message instead of always blaming ``'out'``.

    Returns
    -------
    torch.Tensor or None
        ``arg.tensor`` when ``arg`` is an ``ndarray``; ``None`` when
        ``arg`` is ``None``.

    Raises
    ------
    TypeError
        If ``arg`` is neither ``None`` nor an ``ndarray``.
    """
    if arg is None:
        return arg

    # Imported lazily to avoid a circular import at module load time.
    from ._ndarray import ndarray

    if not isinstance(arg, ndarray):
        # Name the offending parameter when known (review note: the old
        # message hard-coded 'out' even for other NDArray arguments).
        raise TypeError(f"{name or 'argument'!r} must be an array")
    return arg.tensor


def normalize_outarray(arg, name=None):
# almost normalize_ndarray, only return the array, not its tensor
if arg is None:
return arg

Expand All @@ -75,6 +89,8 @@ def normalize_ndarray(arg, name=None):
Optional[ArrayLike]: normalize_optional_array_like,
Sequence[ArrayLike]: normalize_seq_array_like,
Optional[NDArray]: normalize_ndarray,
Optional[OutArray]: normalize_outarray,
NDArray: normalize_ndarray,
DTypeLike: normalize_dtype,
SubokLike: normalize_subok_like,
AxisLike: normalize_axis_like,
Expand Down Expand Up @@ -164,6 +180,9 @@ def wrapped(*args, **kwds):

if "out" in params:
out = sig.bind(*args, **kwds).arguments.get("out")

### if out is not None: breakpoint()

ev-br marked this conversation as resolved.
Show resolved Hide resolved
result = maybe_copy_to(out, result, promote_scalar_result)
result = wrap_tensors(result)

Expand Down
14 changes: 7 additions & 7 deletions torch_np/_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from . import _binary_ufuncs_impl, _helpers, _unary_ufuncs_impl
from ._detail import _dtypes_impl, _util
from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer
from ._normalizations import ArrayLike, DTypeLike, OutArray, SubokLike, normalizer


def _ufunc_preprocess(tensors, where, casting, order, dtype, subok, signature, extobj):
Expand Down Expand Up @@ -46,7 +46,7 @@ def wrapped(
x1: ArrayLike,
x2: ArrayLike,
/,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
*,
where=True,
casting="same_kind",
Expand Down Expand Up @@ -80,7 +80,7 @@ def matmul(
x1: ArrayLike,
x2: ArrayLike,
/,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
*,
casting="same_kind",
order="K",
Expand Down Expand Up @@ -109,10 +109,10 @@ def matmul(
def divmod(
x1: ArrayLike,
x2: ArrayLike,
out1: Optional[NDArray] = None,
out2: Optional[NDArray] = None,
out1: Optional[OutArray] = None,
out2: Optional[OutArray] = None,
/,
out: tuple[Optional[NDArray], Optional[NDArray]] = (None, None),
out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
*,
where=True,
casting="same_kind",
Expand Down Expand Up @@ -181,7 +181,7 @@ def deco_unary_ufunc(torch_func):
def wrapped(
x: ArrayLike,
/,
out: Optional[NDArray] = None,
out: Optional[OutArray] = None,
*,
where=True,
casting="same_kind",
Expand Down
Loading