MAINT: normalize NDArray to tensors, add a special-case for out= NDArrays #108
Conversation
Arguments annotated as NDArray get normalized to their Tensors for further processing by implementer functions. out= arguments, however, are special (the original array needs to be preserved) and are never seen by implementers, so they stay ndarrays.
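A minimal sketch of that split, assuming a toy ndarray wrapper; the names here are illustrative, not the project's actual code:

```python
# Illustrative sketch: NDArray-annotated args become Tensors before the
# implementer runs; out= keeps its ndarray identity for the caller.
class ndarray:
    """Toy stand-in for torch_np's ndarray: wraps a Tensor as .tensor."""
    def __init__(self, tensor):
        self.tensor = tensor

def normalize_ndarray(arg):
    # NDArray annotation: the implementer only ever sees the underlying Tensor
    return arg.tensor if isinstance(arg, ndarray) else arg

def normalize_outarray(arg):
    # out= is special: keep the wrapper itself so the caller's array survives
    if not isinstance(arg, ndarray):
        raise TypeError("'out' must be an array")
    return arg
```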
Why did you change NDArray to OutArray? Note that the only use of out is through out.tensor:
numpy_pytorch_interop/torch_np/_normalizations.py, lines 123 to 124 in f30f4c4:

```python
out.tensor.copy_(result)
return out
```
My best guess is that the Optional[NDArray] annotation we used before was not correct when we had tuples or lists as a return. For reasons like this, I proposed to move the out= implementation out of the function and into the decorator itself, as this input is different from all the others. Now, to avoid the extra packing/unpacking code, we need to make an exception and leak ndarrays into _funcs_impl.py, but well...
Before this PR, out= was annotated as Optional[NDArray]; the OutArray annotation is to fix this. At the end of the day, an out array is different from a generic ndarray argument: NumPy almost never requires strict ndarrays --- here this is literally the only exception.
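A hedged sketch of what that distinction could look like at the annotation level (the exact spellings in _normalizations.py may differ; divmod is just the guinea pig mentioned below):

```python
# Hypothetical sketch: OutArray marks out= so the normalizer can treat it
# differently from a generic NDArray argument.
from typing import Optional

def divmod_impl(
    x1: "NDArray",                      # normalized to a Tensor
    x2: "NDArray",                      # normalized to a Tensor
    out: Optional["OutArray"] = None,   # kept as an ndarray by the decorator
):
    ...
```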
We discussed this a couple of times, actually. This breaks using out as a positional argument, which numpy allows. So my inclination would be to just accept ...
Yes. This is separate, though. I'm fixing this using divmod as a guinea pig now. The current stumbling block is how to check that ...
As mentioned in those discussions, it can be done. Here's a simplified example:

```python
import inspect

def add_out(f):
    def wrapper(*args, **kwargs):
        out = None  # default when no out= is supplied
        if "out" in kwargs:
            out = kwargs.pop("out")
        elif len(args) > len(inspect.getfullargspec(f).args):
            # treat a trailing extra positional as out=
            args, out = args[:-1], args[-1]
        print(out)
        result = f(*args, **kwargs)
        return result
    return wrapper

@add_out
def f(x, *, y=3):
    pass

f(2, 7)
f(2, out=8)
```

This could go under a ...
This is still brittle because there is no guarantee that out is the last positional argument. It really can be anywhere in the argument list.
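One way around that brittleness (a sketch, not the project's code, and it assumes the wrapped function declares out in its signature) is to let inspect.signature do the binding, so out is found wherever it appears:

```python
import inspect

def add_out(f):
    sig = inspect.signature(f)

    def wrapper(*args, **kwargs):
        # bind() matches positionals and keywords to parameter names,
        # so "out" is located no matter where it sits in the call
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        out = bound.arguments.get("out")
        print(out)
        return f(*bound.args, **bound.kwargs)

    return wrapper

@add_out
def f(x, out=None, *, y=3):
    pass

f(2, 7)       # out bound positionally
f(2, out=8)   # out bound by keyword
```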
Fine. Let's just have it like this then.

As a separate point, the recursive step in ...

As for the typing issue, why don't you check that it's exactly a ...
As for how to do that last bit, there's no nice way of doing it in the ...
This PR, with a special ...
This PR as is looks alright. Let's write a comment on the rationale behind this, though (out not being kwarg-only and appearing at arbitrary positions).
Re copy_to and sequences --- yes, I've noticed and am fixing it. As a side note, we need to set up coverage to smoke out these unusual code paths. PrimTorch's way is slick indeed. I was hoping there was a way to avoid looking up ...
Do the primtorch trick, plus:

```python
>>> import typing
>>> issubclass(tuple, typing.Sequence)
True
>>> issubclass(list, typing.Sequence)
True
```
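A sketch of how that check could feed into the copy-back step when out= is a tuple or list of arrays; copy_back is a hypothetical helper name, not the repo's:

```python
# Sketch: dispatch on whether out is a sequence of arrays (tuple/list),
# using the isinstance/issubclass behavior shown above.
import typing

def copy_back(result, out):
    if isinstance(out, typing.Sequence):      # covers tuple and list
        for r, o in zip(result, out):
            o.tensor.copy_(r)
        return out
    out.tensor.copy_(result)                  # single out array
    return out
```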
```python
from ._ndarray import ndarray

if not isinstance(arg, ndarray):
    raise TypeError("'out' must be an array")
```
We should improve the error message using the argument's name. Just patch this into the next PR.
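A hedged sketch of that follow-up (the name parameter is assumed to hold the offending argument's name; ndarray is torch_np's wrapper):

```python
# Illustrative only: report which argument failed the check, and what it was
def check_array(arg, name):
    if not isinstance(arg, ndarray):
        raise TypeError(f"'{name}' must be an array, got {type(arg).__name__}")
```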
OK, gh-109 does a minimal fix for out= being a tuple only, and avoids doing one more refactor. Let's keep the primtorch trick for if/when we hit limitations of the current infra.
So the annotations and their corresponding normalizations are handled at the normalizer level.
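The description cuts off here; a hypothetical sketch of what such an annotation-to-normalization mapping could look like, reusing the sketch functions from the top of this page (names are illustrative, not the repo's actual table):

```python
def normalize_array_like(arg):
    # assumed: coerce anything array-like to a Tensor; stub for the sketch
    ...

# Hypothetical annotation -> normalization table, dispatched by the decorator
normalizers = {
    "ArrayLike": normalize_array_like,   # anything coercible -> Tensor
    "NDArray":   normalize_ndarray,      # ndarray -> its Tensor
    "OutArray":  normalize_outarray,     # out= ndarray, kept as ndarray
}
```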