CUDA

WeiqunZhang committed Dec 7, 2024
1 parent 5872a04 commit 610b40f
Showing 7 changed files with 136 additions and 163 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/dependencies/dependencies_nvcc.sh
@@ -36,5 +36,6 @@ sudo apt-get install -y \
cuda-nvml-dev-$VERSION_DASHED \
cuda-nvtx-$VERSION_DASHED \
libcufft-dev-$VERSION_DASHED \
- libcurand-dev-$VERSION_DASHED
+ libcurand-dev-$VERSION_DASHED \
+ libcusparse-dev-$VERSION_DASHED
sudo ln -s cuda-$VERSION_DOTTED /usr/local/cuda
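
The package added above, libcusparse-dev, provides the <cusparse.h> header and library that the reworked SpMV path below links against. As a hedged sanity check (not part of the commit; the file name and build line are illustrative), the installation can be verified by creating a cuSPARSE handle and querying the library version:

    // check_cusparse.cpp -- build with e.g. `nvcc check_cusparse.cpp -lcusparse`
    #include <cusparse.h>
    #include <cstdio>

    int main ()
    {
        cusparseHandle_t handle = nullptr;
        if (cusparseCreate(&handle) != CUSPARSE_STATUS_SUCCESS) {
            std::printf("cusparseCreate failed\n");
            return 1;
        }
        int version = 0;
        cusparseGetVersion(handle, &version); // e.g. 12xxx for a CUDA 12 toolkit
        std::printf("cuSPARSE version: %d\n", version);
        cusparseDestroy(handle);
        return 0;
    }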
38 changes: 27 additions & 11 deletions Src/Base/AMReX_TableData.H
@@ -15,12 +15,12 @@

namespace amrex {

- template <typename T>
+ template <typename T, typename IDX = int>
struct Table1D
{
T* AMREX_RESTRICT p = nullptr;
- int begin = 1;
- int end = 0;
+ IDX begin = 1;
+ IDX end = 0;

constexpr Table1D () noexcept = default;

@@ -33,7 +33,7 @@ struct Table1D
{}

AMREX_GPU_HOST_DEVICE
- constexpr Table1D (T* a_p, int a_begin, int a_end) noexcept
+ constexpr Table1D (T* a_p, IDX a_begin, IDX a_end) noexcept
: p(a_p),
begin(a_begin),
end(a_end)
@@ -44,7 +44,7 @@

template <class U=T, std::enable_if_t<!std::is_void_v<U>,int> = 0>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
- U& operator() (int i) const noexcept {
+ U& operator() (IDX i) const noexcept {
#if defined(AMREX_DEBUG) || defined(AMREX_BOUND_CHECK)
index_assert(i);
#endif
@@ -53,14 +53,30 @@

#if defined(AMREX_DEBUG) || defined(AMREX_BOUND_CHECK)
AMREX_GPU_HOST_DEVICE inline
- void index_assert (int i) const
+ void index_assert (IDX i) const
{
if (i < begin || i >= end) {
- AMREX_IF_ON_DEVICE((
- AMREX_DEVICE_PRINTF(" (%d) is out of bound (%d:%d)\n",
- i, begin, end-1);
- amrex::Abort();
- ))
+ if constexpr (std::is_same_v<IDX,int>) {
+ AMREX_IF_ON_DEVICE((
+ AMREX_DEVICE_PRINTF(" (%d) is out of bound (%d:%d)\n",
+ i, begin, end-1);
+ amrex::Abort();
+ ))
+ } else if constexpr (std::is_same_v<IDX,long>) {
+ AMREX_IF_ON_DEVICE((
+ AMREX_DEVICE_PRINTF(" (%ld) is out of bound (%ld:%ld)\n",
+ i, begin, end-1);
+ amrex::Abort();
+ ))
+ } else if constexpr (std::is_same_v<IDX,long long>) {
+ AMREX_IF_ON_DEVICE((
+ AMREX_DEVICE_PRINTF(" (%lld) is out of bound (%lld:%lld)\n",
+ i, begin, end-1);
+ amrex::Abort();
+ ))
+ } else {
+ AMREX_IF_ON_DEVICE(( amrex::Abort(); ))
+ }
AMREX_IF_ON_HOST((
std::stringstream ss;
ss << " (" << i << ") is out of bound ("
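
The new IDX template parameter lets Table1D carry a 64-bit index type, which AlgVector relies on below. A minimal host-side sketch (not from the commit; the storage and index range are illustrative) of constructing and indexing a Table1D<double, Long>:

    #include <AMReX_INT.H>
    #include <AMReX_TableData.H>
    #include <vector>

    void fill_offsets ()
    {
        amrex::Long const begin = 3000000000LL; // deliberately larger than INT_MAX
        amrex::Long const n     = 8;
        std::vector<double> storage(n, 0.0);

        // View the storage over the half-open index range [begin, begin+n).
        amrex::Table1D<double, amrex::Long> t(storage.data(), begin, begin + n);

        for (amrex::Long i = t.begin; i < t.end; ++i) {
            t(i) = double(i - t.begin); // operator() now takes the 64-bit index
        }
    }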
58 changes: 7 additions & 51 deletions Src/LinearSolvers/AMReX_AlgVector.H
@@ -6,58 +6,14 @@
#include <AMReX_FabArray.H>
#include <AMReX_INT.H>
#include <AMReX_LayoutData.H>
+ #include <AMReX_TableData.H>

#include <fstream>
#include <string>
#include <type_traits>

namespace amrex {

- template <typename T>
- struct VectorView
- {
- T* AMREX_RESTRICT p = nullptr;
- Long begin = 1;
- Long end = 0;

- AMREX_GPU_HOST_DEVICE
- explicit operator bool() const noexcept { return p != nullptr; }

- template <class U=T, std::enable_if_t<!std::is_void_v<U>,int> = 0>
- [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
- U& operator[] (Long i) const noexcept {
- #if defined(AMREX_DEBUG) || defined(AMREX_BOUND_CHECK)
- index_assert(i);
- #endif
- return p[i-begin];
- }

- [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
- bool contains (Long i) const noexcept { return i>=begin && i<end; }

- #if defined(AMREX_DEBUG) || defined(AMREX_BOUND_CHECK)
- AMREX_GPU_HOST_DEVICE inline
- void index_assert (Long i) const
- {
- if ( ! contains(i)) {
- AMREX_IF_ON_DEVICE((
- // xxxxx TODO: Yes, this is wrong on Windows.
- AMREX_DEVICE_PRINTF(" [%ld] is out of bound [%ld:%ld]\n",
- i, begin, end-1);
- amrex::Abort();
- ))
- AMREX_IF_ON_HOST((
- std::stringstream ss;
- ss << " [" << i << "] is out of bound ["
- << begin << ":" << end-1 << "]";
- amrex::Abort(ss.str());
- ))
- }
- }
- #endif

- };

template <typename T, typename Allocator = DefaultAllocator<T> >
class AlgVector
{
@@ -94,18 +50,18 @@ public:
[[nodiscard]] T * data () { return m_data.data(); }

[[nodiscard]] AMREX_FORCE_INLINE
- VectorView<T const> view () const {
- return VectorView<T const>{m_data.data(), m_begin, m_end};
+ Table1D<T const, Long> view () const {
+ return Table1D<T const, Long>{m_data.data(), m_begin, m_end};
}

[[nodiscard]] AMREX_FORCE_INLINE
- VectorView<T const> const_view () const {
- return VectorView<T const>{m_data.data(), m_begin, m_end};
+ Table1D<T const, Long> const_view () const {
+ return Table1D<T const, Long>{m_data.data(), m_begin, m_end};
}

[[nodiscard]] AMREX_FORCE_INLINE
- VectorView<T> view () {
- return VectorView<T>{m_data.data(), m_begin, m_end};
+ Table1D<T, Long> view () {
+ return Table1D<T, Long>{m_data.data(), m_begin, m_end};
}

void setVal (T val);
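
Callers that previously went through VectorView now receive a Table1D<T, Long> from view(); operator() still subtracts begin, so indexing by global row is unchanged. A hedged sketch (not part of the commit; assumes the usual amrex::ParallelFor overload for an integral count) of scaling an AlgVector through its view:

    #include <AMReX_AlgVector.H>
    #include <AMReX_Gpu.H>

    template <typename T>
    void scale (amrex::AlgVector<T>& v, T factor)
    {
        auto const vv = v.view(); // Table1D<T, Long>
        amrex::ParallelFor(v.numLocalRows(),
            [=] AMREX_GPU_DEVICE (amrex::Long i) noexcept
            {
                vv(vv.begin + i) *= factor; // global row index, offset by the view's begin
            });
    }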
66 changes: 34 additions & 32 deletions Src/LinearSolvers/AMReX_SpMV.H
@@ -3,13 +3,11 @@
#include <AMReX_Config.H>

#include <AMReX_AlgVector.H>
+ #include <AMReX_GpuComplex.H>
#include <AMReX_SpMatrix.H>

#if defined(AMREX_USE_CUDA)
# include <cusparse.h>
- //# if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11)
- //# include <cub/cub.cuh>
- //# endif
#elif defined(AMREX_USE_HIP)
# include <hipsparse.h>
#elif defined(AMREX_USE_DPCPP)
@@ -36,34 +34,29 @@ void SpMV (AlgVector<T>& y, SpMatrix<T> const& A, AlgVector<T> const& x)
#if defined(AMREX_USE_GPU)

Long const nrows = A.numLocalRows();
- Long const ncols = y.numLocalRows();
+ Long const ncols = x.numLocalRows();
Long const nnz = A.numLocalNonZero();

- // y.setVal(0);

#if defined(AMREX_USE_CUDA)

- #if 0

- void* d_temp_storage = nullptr;
- std::size_t temp_storage_bytes = 0;
- cub::DeviceSpmv::CsrMV(d_temp_storage, temp_storage_bytes, (T*)mat, (AlgInt*)row, (AlgInt*)col,
- (T*)px, (T*)py, nrows, ncols, nnz, Gpu::gpuStream());
- d_temp_storage = (void*)The_Arena()->alloc(temp_storage_bytes);
- cub::DeviceSpmv::CsrMV(d_temp_storage, temp_storage_bytes, (T*)mat, (AlgInt*)row, (AlgInt*)col,
- (T*)px, (T*)py, nrows, ncols, nnz, Gpu::gpuStream());
- Gpu::streamSynchronize();
- The_Arena()->free(d_temp_storage);

- #else

cusparseHandle_t handle;
cusparseCreate(&handle);
cusparseSetStream(handle, Gpu::gpuStream());

- cudaDataType data_type = (sizeof(T) == sizeof(double)) ? CUDA_R_64F : CUDA_R_32F;
- cusparseIndexType_t index_type = (sizeof(AlgInt) == sizeof(int)) ?
- CUSPARSE_INDEX_32I : CUSPARSE_INDEX_64I;
+ cudaDataType data_type;
+ if constexpr (std::is_same_v<T,float>) {
+ data_type = CUDA_R_32F;
+ } else if constexpr (std::is_same_v<T,double>) {
+ data_type = CUDA_R_64F;
+ } else if constexpr (std::is_same_v<T,GpuComplex<float>>) {
+ data_type = CUDA_C_32F;
+ } else if constexpr (std::is_same_v<T,GpuComplex<double>>) {
+ data_type = CUDA_C_64F;
+ } else {
+ amrex::Abort("SpMV: unsupported data type");
+ }

+ cusparseIndexType_t index_type = CUSPARSE_INDEX_64I;

cusparseSpMatDescr_t mat_descr;
cusparseCreateCsr(&mat_descr, nrows, ncols, nnz, (void*)row, (void*)col, (void*)mat,
@@ -75,14 +68,14 @@ void SpMV (AlgVector<T>& y, SpMatrix<T> const& A, AlgVector<T> const& x)
cusparseDnVecDescr_t y_descr;
cusparseCreateDnVec(&y_descr, nrows, (void*)py, data_type);

- T alpha = T(1.0);
- T beta = T(0.0);
+ T alpha = T(1);
+ T beta = T(0);

std::size_t buffer_size;
cusparseSpMV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_descr, x_descr,
&beta, y_descr, data_type, CUSPARSE_SPMV_ALG_DEFAULT, &buffer_size);

- void* pbuffer = (void*)The_Arena()->alloc(buffer_size);
+ auto* pbuffer = (void*)The_Arena()->alloc(buffer_size);

cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_descr, x_descr,
&beta, y_descr, data_type, CUSPARSE_SPMV_ALG_DEFAULT, pbuffer);
@@ -95,17 +88,26 @@ void SpMV (AlgVector<T>& y, SpMatrix<T> const& A, AlgVector<T> const& x)
cusparseDestroy(handle);
The_Arena()->free(pbuffer);

- #endif

#elif defined(AMREX_USE_HIP)

hipsparseHandle_t handle;
hipsparseCreate(&handle);
hipsparseSetStream(handle, Gpu::gpuStream());

- hipDataType data_type = (sizeof(T) == sizeof(double)) ? HIP_R_64F : HIP_R_32F;
- hipsparseIndexType_t index_type = (sizeof(AlgInt) == sizeof(int)) ?
- HIPSPARSE_INDEX_32I : HIPSPARSE_INDEX_64I;
+ hipDataType data_type;
+ if constexpr (std::is_same_v<T,float>) {
+ data_type = HIP_R_32F;
+ } else if constexpr (std::is_same_v<T,double>) {
+ data_type = HIP_R_64F;
+ } else if constexpr (std::is_same_v<T,GpuComplex<float>>) {
+ data_type = HIP_C_32F;
+ } else if constexpr (std::is_same_v<T,GpuComplex<double>>) {
+ data_type = HIP_C_64F;
+ } else {
+ amrex::Abort("SpMV: unsupported data type");
+ }

+ hipsparseIndexType_t index_type = HIPSPARSE_INDEX_64I;

hipsparseSpMatDescr_t mat_descr;
hipsparseCreateCsr(&mat_descr, nrows, ncols, nnz, (void*)row, (void*)col, (void*)mat,
@@ -141,7 +143,7 @@ void SpMV (AlgVector<T>& y, SpMatrix<T> const& A, AlgVector<T> const& x)

mkl::sparse::matrix_handle_t handle{};
mkl::sparse::set_csr_data(handle, nrows, ncols, mkl::index_base::zero,
- (AlgInt*)row, (AlgInt*)col, (T*)mat);
+ (Long*)row, (Long*)col, (T*)mat);
mkl::sparse::gemv(Gpu::Device::streamQueue(), mkl::transpose::nontrans,
T(1), handle, px, T(0), py);

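
With the explicit data-type dispatch above, the CUDA and HIP paths of SpMV now accept GpuComplex as well as float and double. A hedged usage sketch (not from the commit; assumes the matrix and vectors have already been built with a consistent row partition):

    #include <AMReX_GpuComplex.H>
    #include <AMReX_SpMV.H>

    using Complex = amrex::GpuComplex<double>;

    void apply (amrex::AlgVector<Complex>& y,
                amrex::SpMatrix<Complex> const& A,
                amrex::AlgVector<Complex> const& x)
    {
        // y = A * x, dispatching to cusparseSpMV (CUDA_C_64F) or hipsparseSpMV (HIP_C_64F).
        amrex::SpMV(y, A, x);
    }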