Skip to content

Commit

Permalink
Merge hipSPARSE bindings for complex types
Browse files Browse the repository at this point in the history
This adds hipSPARSE bindings for `complex<float>` and `complex<double>`.

Related PR: #1538
  • Loading branch information
upsj authored Feb 13, 2024
2 parents 6176550 + a9b7c34 commit 435f184
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions hip/base/hipsparse_bindings.hip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ struct is_supported<float, int32> : std::true_type {};
template <>
struct is_supported<double, int32> : std::true_type {};

template <>
struct is_supported<std::complex<float>, int32> : std::true_type {};

template <>
struct is_supported<std::complex<double>, int32> : std::true_type {};

#define GKO_BIND_HIPSPARSE32_SPMV(ValueType, HipsparseName) \
inline void spmv(hipsparseHandle_t handle, hipsparseOperation_t transA, \
Expand Down Expand Up @@ -88,8 +93,12 @@ struct is_supported<double, int32> : std::true_type {};

GKO_BIND_HIPSPARSE32_SPMV(float, hipsparseScsrmv);
GKO_BIND_HIPSPARSE32_SPMV(double, hipsparseDcsrmv);
GKO_BIND_HIPSPARSE32_SPMV(std::complex<float>, hipsparseCcsrmv);
GKO_BIND_HIPSPARSE32_SPMV(std::complex<double>, hipsparseZcsrmv);
GKO_BIND_HIPSPARSE64_SPMV(float, hipsparseScsrmv);
GKO_BIND_HIPSPARSE64_SPMV(double, hipsparseDcsrmv);
GKO_BIND_HIPSPARSE64_SPMV(std::complex<float>, hipsparseCcsrmv);
GKO_BIND_HIPSPARSE64_SPMV(std::complex<double>, hipsparseZcsrmv);
template <typename ValueType>
GKO_BIND_HIPSPARSE32_SPMV(ValueType, detail::not_implemented);
template <typename ValueType>
Expand Down Expand Up @@ -132,8 +141,12 @@ GKO_BIND_HIPSPARSE64_SPMV(ValueType, detail::not_implemented);

GKO_BIND_HIPSPARSE32_SPMM(float, hipsparseScsrmm);
GKO_BIND_HIPSPARSE32_SPMM(double, hipsparseDcsrmm);
GKO_BIND_HIPSPARSE32_SPMM(std::complex<float>, hipsparseCcsrmm);
GKO_BIND_HIPSPARSE32_SPMM(std::complex<double>, hipsparseZcsrmm);
GKO_BIND_HIPSPARSE64_SPMM(float, hipsparseScsrmm);
GKO_BIND_HIPSPARSE64_SPMM(double, hipsparseDcsrmm);
GKO_BIND_HIPSPARSE64_SPMM(std::complex<float>, hipsparseCcsrmm);
GKO_BIND_HIPSPARSE64_SPMM(std::complex<double>, hipsparseZcsrmm);
template <typename ValueType>
GKO_BIND_HIPSPARSE32_SPMM(ValueType, detail::not_implemented);
template <typename ValueType>
Expand All @@ -160,6 +173,8 @@ GKO_BIND_HIPSPARSE64_SPMM(ValueType, detail::not_implemented);

GKO_BIND_HIPSPARSE32_SPMV(float, hipsparseShybmv);
GKO_BIND_HIPSPARSE32_SPMV(double, hipsparseDhybmv);
GKO_BIND_HIPSPARSE32_SPMV(std::complex<float>, hipsparseChybmv);
GKO_BIND_HIPSPARSE32_SPMV(std::complex<double>, hipsparseZhybmv);
template <typename ValueType>
GKO_BIND_HIPSPARSE32_SPMV(ValueType, detail::not_implemented);

Expand Down Expand Up @@ -369,8 +384,12 @@ GKO_BIND_HIPSPARSE64_CSR2HYB(ValueType, detail::not_implemented);

GKO_BIND_HIPSPARSE_TRANSPOSE32(float, hipsparseScsr2csc);
GKO_BIND_HIPSPARSE_TRANSPOSE32(double, hipsparseDcsr2csc);
GKO_BIND_HIPSPARSE_TRANSPOSE32(std::complex<float>, hipsparseCcsr2csc);
GKO_BIND_HIPSPARSE_TRANSPOSE32(std::complex<double>, hipsparseZcsr2csc);
GKO_BIND_HIPSPARSE_TRANSPOSE64(float, hipsparseScsr2csc);
GKO_BIND_HIPSPARSE_TRANSPOSE64(double, hipsparseDcsr2csc);
GKO_BIND_HIPSPARSE_TRANSPOSE64(std::complex<float>, hipsparseCcsr2csc);
GKO_BIND_HIPSPARSE_TRANSPOSE64(std::complex<double>, hipsparseZcsr2csc);
template <typename ValueType>
GKO_BIND_HIPSPARSE_TRANSPOSE32(ValueType, detail::not_implemented);
template <typename ValueType>
Expand Down Expand Up @@ -402,8 +421,12 @@ GKO_BIND_HIPSPARSE_TRANSPOSE64(ValueType, detail::not_implemented);

GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE32(float, hipsparseScsr2csc);
GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE32(double, hipsparseDcsr2csc);
GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE32(std::complex<float>, hipsparseCcsr2csc);
GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE32(std::complex<double>, hipsparseZcsr2csc);
GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE64(float, hipsparseScsr2csc);
GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE64(double, hipsparseDcsr2csc);
GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE64(std::complex<float>, hipsparseCcsr2csc);
GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE64(std::complex<double>, hipsparseZcsr2csc);
template <typename ValueType>
GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE32(ValueType, detail::not_implemented);
template <typename ValueType>
Expand Down Expand Up @@ -443,8 +466,16 @@ GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE64(ValueType, detail::not_implemented);

GKO_BIND_HIPSPARSE32_CSRSV2_BUFFERSIZE(float, hipsparseScsrsv2_bufferSize);
GKO_BIND_HIPSPARSE32_CSRSV2_BUFFERSIZE(double, hipsparseDcsrsv2_bufferSize);
GKO_BIND_HIPSPARSE32_CSRSV2_BUFFERSIZE(std::complex<float>,
hipsparseCcsrsv2_bufferSize);
GKO_BIND_HIPSPARSE32_CSRSV2_BUFFERSIZE(std::complex<double>,
hipsparseZcsrsv2_bufferSize);
GKO_BIND_HIPSPARSE64_CSRSV2_BUFFERSIZE(float, hipsparseScsrsv2_bufferSize);
GKO_BIND_HIPSPARSE64_CSRSV2_BUFFERSIZE(double, hipsparseDcsrsv2_bufferSize);
GKO_BIND_HIPSPARSE64_CSRSV2_BUFFERSIZE(std::complex<float>,
hipsparseCcsrsv2_bufferSize);
GKO_BIND_HIPSPARSE64_CSRSV2_BUFFERSIZE(std::complex<double>,
hipsparseZcsrsv2_bufferSize);
template <typename ValueType>
GKO_BIND_HIPSPARSE32_CSRSV2_BUFFERSIZE(ValueType, detail::not_implemented);
template <typename ValueType>
Expand Down Expand Up @@ -483,8 +514,16 @@ GKO_BIND_HIPSPARSE64_CSRSV2_BUFFERSIZE(ValueType, detail::not_implemented);

GKO_BIND_HIPSPARSE32_CSRSV2_ANALYSIS(float, hipsparseScsrsv2_analysis);
GKO_BIND_HIPSPARSE32_CSRSV2_ANALYSIS(double, hipsparseDcsrsv2_analysis);
GKO_BIND_HIPSPARSE32_CSRSV2_ANALYSIS(std::complex<float>,
hipsparseCcsrsv2_analysis);
GKO_BIND_HIPSPARSE32_CSRSV2_ANALYSIS(std::complex<double>,
hipsparseZcsrsv2_analysis);
GKO_BIND_HIPSPARSE64_CSRSV2_ANALYSIS(float, hipsparseScsrsv2_analysis);
GKO_BIND_HIPSPARSE64_CSRSV2_ANALYSIS(double, hipsparseDcsrsv2_analysis);
GKO_BIND_HIPSPARSE64_CSRSV2_ANALYSIS(std::complex<float>,
hipsparseCcsrsv2_analysis);
GKO_BIND_HIPSPARSE64_CSRSV2_ANALYSIS(std::complex<double>,
hipsparseZcsrsv2_analysis);
template <typename ValueType>
GKO_BIND_HIPSPARSE32_CSRSV2_ANALYSIS(ValueType, detail::not_implemented);
template <typename ValueType>
Expand Down Expand Up @@ -525,8 +564,12 @@ GKO_BIND_HIPSPARSE64_CSRSV2_ANALYSIS(ValueType, detail::not_implemented);

GKO_BIND_HIPSPARSE32_CSRSV2_SOLVE(float, hipsparseScsrsv2_solve);
GKO_BIND_HIPSPARSE32_CSRSV2_SOLVE(double, hipsparseDcsrsv2_solve);
GKO_BIND_HIPSPARSE32_CSRSV2_SOLVE(std::complex<float>, hipsparseCcsrsv2_solve);
GKO_BIND_HIPSPARSE32_CSRSV2_SOLVE(std::complex<double>, hipsparseZcsrsv2_solve);
GKO_BIND_HIPSPARSE64_CSRSV2_SOLVE(float, hipsparseScsrsv2_solve);
GKO_BIND_HIPSPARSE64_CSRSV2_SOLVE(double, hipsparseDcsrsv2_solve);
GKO_BIND_HIPSPARSE64_CSRSV2_SOLVE(std::complex<float>, hipsparseCcsrsv2_solve);
GKO_BIND_HIPSPARSE64_CSRSV2_SOLVE(std::complex<double>, hipsparseZcsrsv2_solve);
template <typename ValueType>
GKO_BIND_HIPSPARSE32_CSRSV2_SOLVE(ValueType, detail::not_implemented);
template <typename ValueType>
Expand Down

0 comments on commit 435f184

Please sign in to comment.