diff --git a/.github/workflows/dependencies/dependencies_nvcc.sh b/.github/workflows/dependencies/dependencies_nvcc.sh index 2578bd33fe7..14bae699d7e 100755 --- a/.github/workflows/dependencies/dependencies_nvcc.sh +++ b/.github/workflows/dependencies/dependencies_nvcc.sh @@ -36,5 +36,6 @@ sudo apt-get install -y \ cuda-nvml-dev-$VERSION_DASHED \ cuda-nvtx-$VERSION_DASHED \ libcufft-dev-$VERSION_DASHED \ - libcurand-dev-$VERSION_DASHED + libcurand-dev-$VERSION_DASHED \ + libcusparse-dev-$VERSION_DASHED sudo ln -s cuda-$VERSION_DOTTED /usr/local/cuda diff --git a/Src/Base/AMReX_TableData.H b/Src/Base/AMReX_TableData.H index ee2471d36cb..02907fb089a 100644 --- a/Src/Base/AMReX_TableData.H +++ b/Src/Base/AMReX_TableData.H @@ -15,12 +15,12 @@ namespace amrex { -template +template struct Table1D { T* AMREX_RESTRICT p = nullptr; - int begin = 1; - int end = 0; + IDX begin = 1; + IDX end = 0; constexpr Table1D () noexcept = default; @@ -33,7 +33,7 @@ struct Table1D {} AMREX_GPU_HOST_DEVICE - constexpr Table1D (T* a_p, int a_begin, int a_end) noexcept + constexpr Table1D (T* a_p, IDX a_begin, IDX a_end) noexcept : p(a_p), begin(a_begin), end(a_end) @@ -44,7 +44,7 @@ struct Table1D template ,int> = 0> AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE - U& operator() (int i) const noexcept { + U& operator() (IDX i) const noexcept { #if defined(AMREX_DEBUG) || defined(AMREX_BOUND_CHECK) index_assert(i); #endif @@ -53,14 +53,30 @@ struct Table1D #if defined(AMREX_DEBUG) || defined(AMREX_BOUND_CHECK) AMREX_GPU_HOST_DEVICE inline - void index_assert (int i) const + void index_assert (IDX i) const { if (i < begin || i >= end) { - AMREX_IF_ON_DEVICE(( - AMREX_DEVICE_PRINTF(" (%d) is out of bound (%d:%d)\n", - i, begin, end-1); - amrex::Abort(); - )) + if constexpr (std::is_same_v) { + AMREX_IF_ON_DEVICE(( + AMREX_DEVICE_PRINTF(" (%d) is out of bound (%d:%d)\n", + i, begin, end-1); + amrex::Abort(); + )) + } else if constexpr (std::is_same_v) { + AMREX_IF_ON_DEVICE(( + AMREX_DEVICE_PRINTF(" (%ld) is out of bound (%ld:%ld)\n", + i, begin, end-1); + amrex::Abort(); + )) + } else if constexpr (std::is_same_v) { + AMREX_IF_ON_DEVICE(( + AMREX_DEVICE_PRINTF(" (%lld) is out of bound (%lld:%lld)\n", + i, begin, end-1); + amrex::Abort(); + )) + } else { + AMREX_IF_ON_DEVICE(( amrex::Abort(); )) + } AMREX_IF_ON_HOST(( std::stringstream ss; ss << " (" << i << ") is out of bound (" diff --git a/Src/LinearSolvers/AMReX_AlgVector.H b/Src/LinearSolvers/AMReX_AlgVector.H index 92d5fe091d6..07c6fbffeaa 100644 --- a/Src/LinearSolvers/AMReX_AlgVector.H +++ b/Src/LinearSolvers/AMReX_AlgVector.H @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -13,51 +14,6 @@ namespace amrex { -template -struct VectorView -{ - T* AMREX_RESTRICT p = nullptr; - Long begin = 1; - Long end = 0; - - AMREX_GPU_HOST_DEVICE - explicit operator bool() const noexcept { return p != nullptr; } - - template ,int> = 0> - [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE - U& operator[] (Long i) const noexcept { -#if defined(AMREX_DEBUG) || defined(AMREX_BOUND_CHECK) - index_assert(i); -#endif - return p[i-begin]; - } - - [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE - bool contains (Long i) const noexcept { return i>=begin && i > class AlgVector { @@ -94,18 +50,18 @@ public: [[nodiscard]] T * data () { return m_data.data(); } [[nodiscard]] AMREX_FORCE_INLINE - VectorView view () const { - return VectorView{m_data.data(), m_begin, m_end}; + Table1D view () const { + return Table1D{m_data.data(), m_begin, m_end}; } [[nodiscard]] AMREX_FORCE_INLINE - VectorView const_view () const { - return VectorView{m_data.data(), m_begin, m_end}; + Table1D const_view () const { + return Table1D{m_data.data(), m_begin, m_end}; } [[nodiscard]] AMREX_FORCE_INLINE - VectorView view () { - return VectorView{m_data.data(), m_begin, m_end}; + Table1D view () { + return Table1D{m_data.data(), m_begin, m_end}; } void setVal (T val); diff --git a/Src/LinearSolvers/AMReX_SpMV.H b/Src/LinearSolvers/AMReX_SpMV.H index 11d0f3412f5..4a13c2f8030 100644 --- a/Src/LinearSolvers/AMReX_SpMV.H +++ b/Src/LinearSolvers/AMReX_SpMV.H @@ -3,13 +3,11 @@ #include #include +#include #include #if defined(AMREX_USE_CUDA) # include -//# if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11) -//# include -//# endif #elif defined(AMREX_USE_HIP) # include #elif defined(AMREX_USE_DPCPP) @@ -36,34 +34,29 @@ void SpMV (AlgVector& y, SpMatrix const& A, AlgVector const& x) #if defined(AMREX_USE_GPU) Long const nrows = A.numLocalRows(); - Long const ncols = y.numLocalRows(); + Long const ncols = x.numLocalRows(); Long const nnz = A.numLocalNonZero(); - // y.setVal(0); - #if defined(AMREX_USE_CUDA) -#if 0 - - void* d_temp_storage = nullptr; - std::size_t temp_storage_bytes = 0; - cub::DeviceSpmv::CsrMV(d_temp_storage, temp_storage_bytes, (T*)mat, (AlgInt*)row, (AlgInt*)col, - (T*)px, (T*)py, nrows, ncols, nnz, Gpu::gpuStream()); - d_temp_storage = (void*)The_Arena()->alloc(temp_storage_bytes); - cub::DeviceSpmv::CsrMV(d_temp_storage, temp_storage_bytes, (T*)mat, (AlgInt*)row, (AlgInt*)col, - (T*)px, (T*)py, nrows, ncols, nnz, Gpu::gpuStream()); - Gpu::streamSynchronize(); - The_Arena()->free(d_temp_storage); - -#else - cusparseHandle_t handle; cusparseCreate(&handle); cusparseSetStream(handle, Gpu::gpuStream()); - cudaDataType data_type = (sizeof(T) == sizeof(double)) ? CUDA_R_64F : CUDA_R_32F; - cusparseIndexType_t index_type = (sizeof(AlgInt) == sizeof(int)) ? - CUSPARSE_INDEX_32I : CUSPARSE_INDEX_64I; + cudaDataType data_type; + if constexpr (std::is_same_v) { + data_type = CUDA_R_32F; + } else if constexpr (std::is_same_v) { + data_type = CUDA_R_64F; + } else if constexpr (std::is_same_v>) { + data_type = CUDA_C_32F; + } else if constexpr (std::is_same_v>) { + data_type = CUDA_C_64F; + } else { + amrex::Abort("SpMV: unsupported data type"); + } + + cusparseIndexType_t index_type = CUSPARSE_INDEX_64I; cusparseSpMatDescr_t mat_descr; cusparseCreateCsr(&mat_descr, nrows, ncols, nnz, (void*)row, (void*)col, (void*)mat, @@ -75,14 +68,14 @@ void SpMV (AlgVector& y, SpMatrix const& A, AlgVector const& x) cusparseDnVecDescr_t y_descr; cusparseCreateDnVec(&y_descr, nrows, (void*)py, data_type); - T alpha = T(1.0); - T beta = T(0.0); + T alpha = T(1); + T beta = T(0); std::size_t buffer_size; cusparseSpMV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_descr, x_descr, &beta, y_descr, data_type, CUSPARSE_SPMV_ALG_DEFAULT, &buffer_size); - void* pbuffer = (void*)The_Arena()->alloc(buffer_size); + auto* pbuffer = (void*)The_Arena()->alloc(buffer_size); cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_descr, x_descr, &beta, y_descr, data_type, CUSPARSE_SPMV_ALG_DEFAULT, pbuffer); @@ -95,17 +88,26 @@ void SpMV (AlgVector& y, SpMatrix const& A, AlgVector const& x) cusparseDestroy(handle); The_Arena()->free(pbuffer); -#endif - #elif defined(AMREX_USE_HIP) hipsparseHandle_t handle; hipsparseCreate(&handle); hipsparseSetStream(handle, Gpu::gpuStream()); - hipDataType data_type = (sizeof(T) == sizeof(double)) ? HIP_R_64F : HIP_R_32F; - hipsparseIndexType_t index_type = (sizeof(AlgInt) == sizeof(int)) ? - HIPSPARSE_INDEX_32I : HIPSPARSE_INDEX_64I; + hipDataType data_type; + if constexpr (std::is_same_v) { + data_type = HIP_R_32F; + } else if constexpr (std::is_same_v) { + data_type = HIP_R_64F; + } else if constexpr (std::is_same_v>) { + data_type = HIP_C_32F; + } else if constexpr (std::is_same_v>) { + data_type = HIP_C_64F; + } else { + amrex::Abort("SpMV: unsupported data type"); + } + + hipsparseIndexType_t index_type = HIPSPARSE_INDEX_64I; hipsparseSpMatDescr_t mat_descr; hipsparseCreateCsr(&mat_descr, nrows, ncols, nnz, (void*)row, (void*)col, (void*)mat, @@ -141,7 +143,7 @@ void SpMV (AlgVector& y, SpMatrix const& A, AlgVector const& x) mkl::sparse::matrix_handle_t handle{}; mkl::sparse::set_csr_data(handle, nrows, ncols, mkl::index_base::zero, - (AlgInt*)row, (AlgInt*)col, (T*)mat); + (Long*)row, (Long*)col, (T*)mat); mkl::sparse::gemv(Gpu::Device::streamQueue(), mkl::transpose::nontrans, T(1), handle, px, T(0), py); diff --git a/Src/LinearSolvers/AMReX_SpMatrix.H b/Src/LinearSolvers/AMReX_SpMatrix.H index 610ca899e26..84dca089a75 100644 --- a/Src/LinearSolvers/AMReX_SpMatrix.H +++ b/Src/LinearSolvers/AMReX_SpMatrix.H @@ -57,19 +57,21 @@ public: template friend void SpMV(AlgVector& y, SpMatrix const& A, AlgVector const& x); -private: - + //! Private function, but public for cuda void define_doit (int nnz); - void startComm (AlgVector const& x); - void finishComm (AlgVector& y); - #ifdef AMREX_USE_MPI + //! Private function, but public for cuda void prepare_comm (); void pack_buffer (AlgVector const& v); void unpack_buffer (AlgVector& v); #endif +private: + + void startComm (AlgVector const& x); + void finishComm (AlgVector& y); + template using DVec = PODVector >; template