From c289ec639a69781b9a366e37448b804da58dfdc5 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 10 May 2022 09:41:34 -0700 Subject: [PATCH] Fix sddmm2 when nnz=0 Also add cudaGetLastError which was missing --- tests/test_custom_ops.py | 10 +++++----- xformers/components/attention/csrc/cuda/sddmm2_cuda.cu | 8 ++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/test_custom_ops.py b/tests/test_custom_ops.py index 6d5053e52f..bef8b41021 100644 --- a/tests/test_custom_ops.py +++ b/tests/test_custom_ops.py @@ -158,14 +158,14 @@ def test_sddmm_sputnik(device): @cuda_only +@pytest.mark.parametrize("prob", [0.5, 1]) @pytest.mark.parametrize("K", [32, 17]) @pytest.mark.parametrize("M", [30, 17]) @pytest.mark.parametrize("L", [30, 17]) -def test_sddmm_csr(L, M, K): +def test_sddmm_csr(L, M, K, prob): device = torch.device("cuda") # TODO add more checks for different nnz B = 8 - prob = 0.5 a = torch.rand(B, L, K, device=device) b = torch.rand(B, M, K, device=device) mask = _create_random_sparsity( @@ -188,7 +188,7 @@ def test_sddmm_csr(L, M, K): @cuda_only -@pytest.mark.parametrize("nnz", [4, 16, 20, 36]) +@pytest.mark.parametrize("nnz", [0, 4, 16, 20, 36]) def test_sddmm_csr_per_nnz(nnz): device = torch.device("cuda") B = 8 @@ -215,14 +215,14 @@ def test_sddmm_csr_per_nnz(nnz): @cuda_only +@pytest.mark.parametrize("prob", [0.5, 1]) @pytest.mark.parametrize("K", [32, 17]) @pytest.mark.parametrize("M", [30, 17]) @pytest.mark.parametrize("L", [30, 17]) -def test_sddmm_coo(L, M, K): +def test_sddmm_coo(L, M, K, prob): device = torch.device("cuda") # TODO add more checks for different nnz B = 8 - prob = 0.5 a = torch.rand(B, L, K, device=device) b = torch.rand(B, M, K, device=device) mask = _create_random_sparsity( diff --git a/xformers/components/attention/csrc/cuda/sddmm2_cuda.cu b/xformers/components/attention/csrc/cuda/sddmm2_cuda.cu index 03945682a0..a6179b6193 100644 --- a/xformers/components/attention/csrc/cuda/sddmm2_cuda.cu +++ b/xformers/components/attention/csrc/cuda/sddmm2_cuda.cu @@ -457,6 +457,9 @@ torch::Tensor sddmm_cuda_coo( const auto nnz = rowind.size(0); auto out = torch::empty({batch_size, nnz}, D1.options()); + if (out.numel() == 0) + return out; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); dim3 grid_dim(nnz / 16 + (nnz & 15), batch_size, 1); if ((k % 4) == 0) { @@ -496,6 +499,7 @@ torch::Tensor sddmm_cuda_coo( D2.data_ptr(), out.data_ptr()); } + AT_CUDA_CHECK(cudaGetLastError()); return out; } @@ -511,6 +515,9 @@ torch::Tensor sddmm_cuda_csr( const auto nnz = colind.size(0); auto out = torch::empty({batch_size, nnz}, D1.options()); + if (out.numel() == 0) + return out; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); dim3 grid_dim(nnz / 16 + (nnz & 15), batch_size, 1); @@ -539,6 +546,7 @@ torch::Tensor sddmm_cuda_csr( D2.data_ptr(), out.data_ptr()); } + AT_CUDA_CHECK(cudaGetLastError()); return out; } } // namespace ge_spmm