-
Notifications
You must be signed in to change notification settings - Fork 23.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introducing array-like sequence methods __contains__ #17733
Conversation
@colesbury where to add documentation for this? also, I don't think we need reversed and cotains aten implementation as this is high level functions and internal functions being aten should be okay |
hmm, strange, I thought reversed was already added in #9216 |
@bhushan23 the docs should go into _tensor_docs i think |
@pytorchbot retest this please |
torch/tensor.py
Outdated
@@ -426,6 +426,10 @@ def __array_wrap__(self, array): | |||
array = array.astype('uint8') | |||
return torch.from_numpy(array) | |||
|
|||
def __contains__(self, val): | |||
r"""Check if `val` is present""" | |||
return (val == self).any().item() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't you make sure that val
is a pytorch/python(or numpy?) scalar first, and return NotImplemented otherwise?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about the case where we need to check if a row existed in a matrix?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ssnl throwing NotImplemented in not tensor or scalar.
@vishwakftw NumPy also does not check if a row exists in a matrix.
Current behavior: any one element being present returns true.
Certainly, checking if row exists will be good to have.
Question is whether to make it default or under some option.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also test is needed
for tensor Test added- 1. test_contains
torch/tensor.py
Outdated
""" | ||
if torch.is_tensor(element) or isinstance(element, Number): | ||
return (element == self).any().item() | ||
raise NotImplementedError( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm I think you should return NotImplemented
rather than raising NotImplementedError
so that the default behavior using __iter__
can be invoked.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
strainge observation
>>> a = torch.arange(10)
>>> 'b' in a // this returns True, because of fall back to __iter__
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
string 'b'
? that doesn't look right...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah.
I checked about NotImplemented.
Looks like we are typecasting it to bool and returning somewhere
bool(NotImplemented) is true
That's why it's not falling back to __iter__
based method
btw, why should we fall back to basic method as it will still lead to error if it's not tensor or scalar
And Isn't specific error by us preferable than relying on standard __iter__
based method to give random error.
ref: https://medium.com/@s16h/pythons-notimplemented-type-2d720137bf41
torch/tensor.py
Outdated
element (Tensor or scalar): element to be checked | ||
for presence in current tensor" | ||
""" | ||
if torch.is_tensor(element) or isinstance(element, Number): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use isinstance(element, (torch.Tensor, Number))
instead
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
- returning NotImplemented instead of Error
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
for tensor
Fixes: #17000