Fix incorrect sparse add behavior when the sparse tensor has non-contiguous values #18179
Conversation
idk if this test failure is legit
It would be really helpful for review if the PR message explained how exactly the problem was solved.
```cpp
return r._coalesced_(t_coalesced && s_coalesced);
LongTensor r_indices = at::cat({t_indices, s_indices}, 1);
Tensor r_values = at::cat({t_values, s_values}, 0);
alias_into_sparse(r, r_indices, r_values);
```
If you cat'ed, don't you have to specify that the output is not coalesced?
IMO we should make this a parameter on `alias_into_sparse` so people have to consider it.
`alias_into_sparse(...)` calls `set_indices_and_values_unsafe(...)` internally, which always sets `coalesced_ = false`, and we expect users to call `sparse_tensor._coalesced_(...)` afterwards if they want to change the coalesced-ness of the sparse tensor. For example:
pytorch/aten/src/ATen/native/sparse/SparseTensor.cpp, lines 457 to 458 in 1c671c5:

```cpp
alias_into_sparse(r, mask_indices.clone(), r_values);
r._coalesced_(mask.is_coalesced());
```
To simplify this API, we can add an `is_coalesced` parameter on `alias_into_sparse`, possibly in a separate PR.
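A minimal sketch of that suggestion (hypothetical, not part of this PR; it assumes the existing `get_sparse_impl` and `set_indices_and_values_unsafe` helpers from ATen's sparse utilities):

```cpp
// Hypothetical variant of alias_into_sparse with an explicit is_coalesced
// parameter, so every caller must state the coalesced-ness of the result
// instead of relying on a follow-up _coalesced_() call.
inline void alias_into_sparse(const SparseTensor& self,
                              const LongTensor& indices,
                              const Tensor& values,
                              bool is_coalesced) {
  get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values);
  self._coalesced_(is_coalesced);
}
```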
Can we get some benchmark numbers? I'm not sure if some of our embedding examples exercise sparse-sparse, but if they do, that would be most representative. I don't think it's necessarily wrong to switch to cat'ing the indices and values together, but I feel you could have also fixed the problem by simply switching values to use an accessor (which respects strides) rather than pointer arithmetic (which doesn't). So the algorithm change should be justified.
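For illustration, a hedged sketch of that accessor alternative (hypothetical helper functions, not this PR's change; the real kernel dispatches over scalar types rather than hard-coding float):

```cpp
#include <ATen/ATen.h>

// Accessor indexing multiplies by the tensor's actual strides, so this is
// correct even when `src` is non-contiguous (e.g. produced by expand()).
void add_row_accessor(at::Tensor& dst, const at::Tensor& src, int64_t row) {
  auto d = dst.accessor<float, 2>();
  auto s = src.accessor<float, 2>();
  for (int64_t j = 0; j < src.size(1); j++) {
    d[row][j] += s[row][j];
  }
}

// The buggy pattern: the offset row * blockSize only points at row `row`
// when `src` is contiguous, so expanded (stride-0) values are read from
// the wrong locations.
void add_row_pointer(at::Tensor& dst, const at::Tensor& src, int64_t row) {
  const int64_t blockSize = src.size(1);
  float* d = dst.data_ptr<float>();
  const float* s = src.data_ptr<float>();
  for (int64_t j = 0; j < blockSize; j++) {
    d[row * blockSize + j] += s[row * blockSize + j];
  }
}
```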
How about only catting if the tensors aren't contiguous? That way we only (potentially) slow down paths that were broken anyway.
```cpp
int64_t blockSize = r_values.stride(0);
int64_t cmp, d;
int64_t r_i = 0, t_i = 0, s_i = 0;
if (s_values.is_contiguous() && t_values.is_contiguous()) {
```
There is no change in this if-branch compared to the original code; I only indented it.
```cpp
  // index goes backwards) which may be more precise than using the
  // coalesced flag here. But this is easy.
  return r._coalesced_(t_coalesced && s_coalesced);
} else {
```
This if-branch is the actual addition.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@yf225 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fix incorrect sparse add behavior when the sparse tensor has non-contiguous values (#18179)

Summary: Currently, this code gives an incorrect result:

```python
import torch

indices = torch.tensor([[7, 1, 3]])
values = torch.tensor([[1., 1., 1.],
                       [1., 1., 1.],
                       [1., 1., 1.]])
x = torch.sparse_coo_tensor(indices, values, size=(10, 3))

values = torch.tensor(1.).expand(3, 3)
y = torch.sparse_coo_tensor(indices, values, size=(10, 3))

z = x + y
# z is:
# tensor(indices=tensor([[7, 1, 3]]),
#        values=tensor([[2., 1., 1.],
#                       [1., 1., 1.],
#                       [1., 1., 1.]]),
#        size=(10, 3), nnz=3, layout=torch.sparse_coo)
```

Since every value of x and y is 1., every entry of z's values should be 2.; instead only the first element is incremented, because y's expanded values are non-contiguous and the old kernel read them with pointer arithmetic that assumes contiguity.

This PR fixes the bug by adding special handling for sparse tensors with non-contiguous values in the addition function (specifically, by cat'ing the indices and values together).

This PR closes pytorch/pytorch#17950 and pytorch/pytorch#17919.

Pull Request resolved: pytorch/pytorch#18179
Reviewed By: ezyang
Differential Revision: D14569591
Pulled By: yf225
fbshipit-source-id: f5a14c4a31337fc95eab64596212066b4fb18b1a
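A standalone check of the non-contiguity at the heart of the bug (a minimal sketch in the libtorch C++ API for consistency with the snippets above; `expand()` behaves the same as in the Python repro):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // expand() returns a stride-0 view rather than copying data, so the
  // expanded values tensor is non-contiguous.
  torch::Tensor values = torch::ones({}).expand({3, 3});
  std::cout << std::boolalpha << values.is_contiguous() << "\n";          // false
  std::cout << values.stride(0) << ", " << values.stride(1) << "\n";     // 0, 0
  return 0;
}
```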