Fix sddmm2 when nnz=0 #300

Merged
merged 1 commit into from
May 10, 2022

@fmassa fmassa commented May 10, 2022

One of our internal implementations of sampled dense-dense matrix multiplication (sddmm) had two issues:

  • It didn't guard kernel launches when nnz=0 (which would yield a grid size of 0 in one of its dimensions).
  • It didn't call cudaGetLastError at the end of the function, so errors in this kernel would only be reported at the next function invocation, hiding the true location of the issue.
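
To illustrate the first issue, here is a minimal sketch of how a grid dimension derived from nnz can end up as 0. The ceiling-division launch shape, the block size of 64, and the helper names are assumptions for illustration and are not taken from the actual kernel. The sketch is plain Python so that it runs without a GPU.

```python
def ceil_div(n: int, block: int) -> int:
    """Number of blocks of size `block` needed to cover `n` items."""
    return (n + block - 1) // block


def grid_dim_for(nnz: int, threads_per_block: int = 64) -> int:
    # Hypothetical launch configuration: one thread per nonzero.
    return ceil_div(nnz, threads_per_block)


def is_valid_grid_dim(dim: int) -> bool:
    # CUDA requires every grid dimension to be at least 1.
    return dim >= 1


for nnz in (0, 1, 64, 65):
    dim = grid_dim_for(nnz)
    status = "ok" if is_valid_grid_dim(dim) else "invalid launch"
    print(f"nnz={nnz:>2} grid_dim={dim} ({status})")
```

With nnz=0 the computed dimension is 0, which CUDA rejects as an invalid launch configuration, so the launch has to be skipped or guarded.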

This PR fixes both by returning early when nnz=0 (which is fine, as there is no data in the tensor anyway) and by adding the missing cudaGetLastError call.
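
As a rough illustration of why the early return is harmless, here is a pure-Python reference sketch of sddmm over a CSR mask. The function name, the argument layout, and the convention that the result is `a @ b^T` sampled at the stored positions are assumptions made for this example, not the signature of the CUDA implementation.

```python
from typing import List, Sequence


def sddmm_csr_reference(
    row_offsets: Sequence[int],
    column_indices: Sequence[int],
    a: Sequence[Sequence[float]],
    b: Sequence[Sequence[float]],
) -> List[float]:
    """Return (a @ b^T)[i, j] for every stored position (i, j) of a CSR mask."""
    nnz = len(column_indices)
    if nnz == 0:
        # Nothing is sampled, so the result is an empty value array and no
        # further work (on a GPU: no kernel launch) is needed.
        return []

    out = [0.0] * nnz
    for i in range(len(row_offsets) - 1):
        for k in range(row_offsets[i], row_offsets[i + 1]):
            j = column_indices[k]
            # Dot product of row i of `a` with row j of `b`.
            out[k] = sum(x * y for x, y in zip(a[i], b[j]))
    return out


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[1.0, 0.0], [0.0, 1.0]]

# Mask with stored entries at (0, 1) and (1, 0).
print(sddmm_csr_reference([0, 1, 2], [1, 0], a, b))
# Fully empty mask: nnz == 0.
print(sddmm_csr_reference([0, 0, 0], [], a, b))
```

The output has exactly one value per stored position, so an empty mask yields an empty result regardless of the dense inputs.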

Should fix the errors present in #263

Also add cudaGetLastError which was missing
@fmassa fmassa requested a review from blefaudeux May 10, 2022 16:45
@facebook-github-bot facebook-github-bot added the CLA Signed label May 10, 2022
@fmassa fmassa mentioned this pull request May 10, 2022
@@ -158,14 +158,14 @@ def test_sddmm_sputnik(device):


@cuda_only
@pytest.mark.parametrize("prob", [0.5, 1])

thanks for adding this, nice catch

@@ -539,6 +546,7 @@ torch::Tensor sddmm_cuda_csr(
D2.data_ptr<float>(),
out.data_ptr<float>());
}
AT_CUDA_CHECK(cudaGetLastError());

ok, makes sense :)

@blefaudeux blefaudeux left a comment

LGTM, thanks a lot @fmassa for diving in and the very quick fix

@blefaudeux

the mypy error is fixed on main, will be fine on landing

@fmassa fmassa merged commit bcedfaf into main May 10, 2022
@fmassa fmassa deleted the sddmm2-zero-nnz-fix branch May 10, 2022 17:10