diff --git a/hip/base/hipsparse_bindings.hip.hpp b/hip/base/hipsparse_bindings.hip.hpp index 0a6a3dad23f..62c7e60995e 100644 --- a/hip/base/hipsparse_bindings.hip.hpp +++ b/hip/base/hipsparse_bindings.hip.hpp @@ -57,6 +57,11 @@ struct is_supported : std::true_type {}; template <> struct is_supported : std::true_type {}; +template <> +struct is_supported, int32> : std::true_type {}; + +template <> +struct is_supported, int32> : std::true_type {}; #define GKO_BIND_HIPSPARSE32_SPMV(ValueType, HipsparseName) \ inline void spmv(hipsparseHandle_t handle, hipsparseOperation_t transA, \ @@ -88,8 +93,12 @@ struct is_supported : std::true_type {}; GKO_BIND_HIPSPARSE32_SPMV(float, hipsparseScsrmv); GKO_BIND_HIPSPARSE32_SPMV(double, hipsparseDcsrmv); +GKO_BIND_HIPSPARSE32_SPMV(std::complex, hipsparseCcsrmv); +GKO_BIND_HIPSPARSE32_SPMV(std::complex, hipsparseZcsrmv); GKO_BIND_HIPSPARSE64_SPMV(float, hipsparseScsrmv); GKO_BIND_HIPSPARSE64_SPMV(double, hipsparseDcsrmv); +GKO_BIND_HIPSPARSE64_SPMV(std::complex, hipsparseCcsrmv); +GKO_BIND_HIPSPARSE64_SPMV(std::complex, hipsparseZcsrmv); template GKO_BIND_HIPSPARSE32_SPMV(ValueType, detail::not_implemented); template @@ -132,8 +141,12 @@ GKO_BIND_HIPSPARSE64_SPMV(ValueType, detail::not_implemented); GKO_BIND_HIPSPARSE32_SPMM(float, hipsparseScsrmm); GKO_BIND_HIPSPARSE32_SPMM(double, hipsparseDcsrmm); +GKO_BIND_HIPSPARSE32_SPMM(std::complex, hipsparseCcsrmm); +GKO_BIND_HIPSPARSE32_SPMM(std::complex, hipsparseZcsrmm); GKO_BIND_HIPSPARSE64_SPMM(float, hipsparseScsrmm); GKO_BIND_HIPSPARSE64_SPMM(double, hipsparseDcsrmm); +GKO_BIND_HIPSPARSE64_SPMM(std::complex, hipsparseCcsrmm); +GKO_BIND_HIPSPARSE64_SPMM(std::complex, hipsparseZcsrmm); template GKO_BIND_HIPSPARSE32_SPMM(ValueType, detail::not_implemented); template @@ -160,6 +173,8 @@ GKO_BIND_HIPSPARSE64_SPMM(ValueType, detail::not_implemented); GKO_BIND_HIPSPARSE32_SPMV(float, hipsparseShybmv); GKO_BIND_HIPSPARSE32_SPMV(double, hipsparseDhybmv); +GKO_BIND_HIPSPARSE32_SPMV(std::complex, hipsparseChybmv); +GKO_BIND_HIPSPARSE32_SPMV(std::complex, hipsparseZhybmv); template GKO_BIND_HIPSPARSE32_SPMV(ValueType, detail::not_implemented); @@ -369,8 +384,12 @@ GKO_BIND_HIPSPARSE64_CSR2HYB(ValueType, detail::not_implemented); GKO_BIND_HIPSPARSE_TRANSPOSE32(float, hipsparseScsr2csc); GKO_BIND_HIPSPARSE_TRANSPOSE32(double, hipsparseDcsr2csc); +GKO_BIND_HIPSPARSE_TRANSPOSE32(std::complex, hipsparseCcsr2csc); +GKO_BIND_HIPSPARSE_TRANSPOSE32(std::complex, hipsparseZcsr2csc); GKO_BIND_HIPSPARSE_TRANSPOSE64(float, hipsparseScsr2csc); GKO_BIND_HIPSPARSE_TRANSPOSE64(double, hipsparseDcsr2csc); +GKO_BIND_HIPSPARSE_TRANSPOSE64(std::complex, hipsparseCcsr2csc); +GKO_BIND_HIPSPARSE_TRANSPOSE64(std::complex, hipsparseZcsr2csc); template GKO_BIND_HIPSPARSE_TRANSPOSE32(ValueType, detail::not_implemented); template @@ -402,8 +421,12 @@ GKO_BIND_HIPSPARSE_TRANSPOSE64(ValueType, detail::not_implemented); GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE32(float, hipsparseScsr2csc); GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE32(double, hipsparseDcsr2csc); +GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE32(std::complex, hipsparseCcsr2csc); +GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE32(std::complex, hipsparseZcsr2csc); GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE64(float, hipsparseScsr2csc); GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE64(double, hipsparseDcsr2csc); +GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE64(std::complex, hipsparseCcsr2csc); +GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE64(std::complex, hipsparseZcsr2csc); template GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE32(ValueType, detail::not_implemented); template @@ -443,8 +466,16 @@ GKO_BIND_HIPSPARSE_CONJ_TRANSPOSE64(ValueType, detail::not_implemented); GKO_BIND_HIPSPARSE32_CSRSV2_BUFFERSIZE(float, hipsparseScsrsv2_bufferSize); GKO_BIND_HIPSPARSE32_CSRSV2_BUFFERSIZE(double, hipsparseDcsrsv2_bufferSize); +GKO_BIND_HIPSPARSE32_CSRSV2_BUFFERSIZE(std::complex, + hipsparseCcsrsv2_bufferSize); +GKO_BIND_HIPSPARSE32_CSRSV2_BUFFERSIZE(std::complex, + hipsparseZcsrsv2_bufferSize); GKO_BIND_HIPSPARSE64_CSRSV2_BUFFERSIZE(float, hipsparseScsrsv2_bufferSize); GKO_BIND_HIPSPARSE64_CSRSV2_BUFFERSIZE(double, hipsparseDcsrsv2_bufferSize); +GKO_BIND_HIPSPARSE64_CSRSV2_BUFFERSIZE(std::complex, + hipsparseCcsrsv2_bufferSize); +GKO_BIND_HIPSPARSE64_CSRSV2_BUFFERSIZE(std::complex, + hipsparseZcsrsv2_bufferSize); template GKO_BIND_HIPSPARSE32_CSRSV2_BUFFERSIZE(ValueType, detail::not_implemented); template @@ -483,8 +514,16 @@ GKO_BIND_HIPSPARSE64_CSRSV2_BUFFERSIZE(ValueType, detail::not_implemented); GKO_BIND_HIPSPARSE32_CSRSV2_ANALYSIS(float, hipsparseScsrsv2_analysis); GKO_BIND_HIPSPARSE32_CSRSV2_ANALYSIS(double, hipsparseDcsrsv2_analysis); +GKO_BIND_HIPSPARSE32_CSRSV2_ANALYSIS(std::complex, + hipsparseCcsrsv2_analysis); +GKO_BIND_HIPSPARSE32_CSRSV2_ANALYSIS(std::complex, + hipsparseZcsrsv2_analysis); GKO_BIND_HIPSPARSE64_CSRSV2_ANALYSIS(float, hipsparseScsrsv2_analysis); GKO_BIND_HIPSPARSE64_CSRSV2_ANALYSIS(double, hipsparseDcsrsv2_analysis); +GKO_BIND_HIPSPARSE64_CSRSV2_ANALYSIS(std::complex, + hipsparseCcsrsv2_analysis); +GKO_BIND_HIPSPARSE64_CSRSV2_ANALYSIS(std::complex, + hipsparseZcsrsv2_analysis); template GKO_BIND_HIPSPARSE32_CSRSV2_ANALYSIS(ValueType, detail::not_implemented); template @@ -525,8 +564,12 @@ GKO_BIND_HIPSPARSE64_CSRSV2_ANALYSIS(ValueType, detail::not_implemented); GKO_BIND_HIPSPARSE32_CSRSV2_SOLVE(float, hipsparseScsrsv2_solve); GKO_BIND_HIPSPARSE32_CSRSV2_SOLVE(double, hipsparseDcsrsv2_solve); +GKO_BIND_HIPSPARSE32_CSRSV2_SOLVE(std::complex, hipsparseCcsrsv2_solve); +GKO_BIND_HIPSPARSE32_CSRSV2_SOLVE(std::complex, hipsparseZcsrsv2_solve); GKO_BIND_HIPSPARSE64_CSRSV2_SOLVE(float, hipsparseScsrsv2_solve); GKO_BIND_HIPSPARSE64_CSRSV2_SOLVE(double, hipsparseDcsrsv2_solve); +GKO_BIND_HIPSPARSE64_CSRSV2_SOLVE(std::complex, hipsparseCcsrsv2_solve); +GKO_BIND_HIPSPARSE64_CSRSV2_SOLVE(std::complex, hipsparseZcsrsv2_solve); template GKO_BIND_HIPSPARSE32_CSRSV2_SOLVE(ValueType, detail::not_implemented); template