diff --git a/accessor/cuda_helper.hpp b/accessor/cuda_helper.hpp index 31d3599516d..3efc6eb22b7 100644 --- a/accessor/cuda_helper.hpp +++ b/accessor/cuda_helper.hpp @@ -17,7 +17,15 @@ #include "utils.hpp" +struct __half; + + namespace gko { + + +class half; + + namespace acc { namespace detail { @@ -27,6 +35,11 @@ struct cuda_type { using type = T; }; +template <> +struct cuda_type { + using type = __half; +}; + // Unpack cv and reference / pointer qualifiers template struct cuda_type { @@ -57,7 +70,7 @@ struct cuda_type { // Transform std::complex to thrust::complex template struct cuda_type> { - using type = thrust::complex; + using type = thrust::complex::type>; }; diff --git a/accessor/hip_helper.hpp b/accessor/hip_helper.hpp index 6b76b726c10..8827fd6eb11 100644 --- a/accessor/hip_helper.hpp +++ b/accessor/hip_helper.hpp @@ -17,7 +17,15 @@ #include "utils.hpp" +struct __half; + + namespace gko { + + +class half; + + namespace acc { namespace detail { @@ -53,11 +61,15 @@ struct hip_type { using type = typename hip_type::type&&; }; +template <> +struct hip_type { + using type = __half; +}; // Transform std::complex to thrust::complex template struct hip_type> { - using type = thrust::complex; + using type = thrust::complex::type>; }; diff --git a/cuda/base/types.hpp b/cuda/base/types.hpp index 7252f7d673d..a2abdbce898 100644 --- a/cuda/base/types.hpp +++ b/cuda/base/types.hpp @@ -14,20 +14,17 @@ #include #include +#include #include #include namespace gko { - namespace kernels { namespace cuda { - - namespace detail { - /** * @internal * @@ -124,6 +121,17 @@ struct culibs_type_impl> { using type = cuDoubleComplex; }; + +template <> +struct culibs_type_impl { + using type = __half; +}; + +template <> +struct culibs_type_impl> { + using type = __half2; +}; + template struct culibs_type_impl> { using type = typename culibs_type_impl>::type; @@ -154,9 +162,14 @@ struct cuda_type_impl { using type = volatile typename cuda_type_impl::type; }; +template <> +struct cuda_type_impl { + using type = __half; +}; + template struct cuda_type_impl> { - using type = thrust::complex; + using type = thrust::complex::type>; }; template <> @@ -169,6 +182,11 @@ struct cuda_type_impl { using type = thrust::complex; }; +template <> +struct cuda_type_impl<__half2> { + using type = thrust::complex<__half>; +}; + template struct cuda_struct_member_type_impl { using type = T; @@ -176,7 +194,12 @@ struct cuda_struct_member_type_impl { template struct cuda_struct_member_type_impl> { - using type = fake_complex; + using type = fake_complex::type>; +}; + +template <> +struct cuda_struct_member_type_impl { + using type = __half; }; template @@ -200,6 +223,7 @@ GKO_CUDA_DATA_TYPE(float, CUDA_R_32F); GKO_CUDA_DATA_TYPE(double, CUDA_R_64F); GKO_CUDA_DATA_TYPE(std::complex, CUDA_C_32F); GKO_CUDA_DATA_TYPE(std::complex, CUDA_C_64F); +GKO_CUDA_DATA_TYPE(std::complex, CUDA_C_16F); GKO_CUDA_DATA_TYPE(int32, CUDA_R_32I); GKO_CUDA_DATA_TYPE(int8, CUDA_R_8I); diff --git a/hip/base/types.hip.hpp b/hip/base/types.hip.hpp index bb0d4a2d0c9..c3982b7562e 100644 --- a/hip/base/types.hip.hpp +++ b/hip/base/types.hip.hpp @@ -21,14 +21,13 @@ #endif #include +#include #include #include "common/cuda_hip/base/runtime.hpp" namespace gko { - - namespace kernels { namespace hip { namespace detail { @@ -130,6 +129,17 @@ struct hiplibs_type_impl> { using type = hipDoubleComplex; }; +template <> +struct hiplibs_type_impl { + using type = __half; +}; + +template <> +struct hiplibs_type_impl> { + using type = __half2; +}; + + template struct hiplibs_type_impl> { using type = typename hiplibs_type_impl>::type; @@ -202,9 +212,14 @@ struct hip_type_impl { using type = volatile typename hip_type_impl::type; }; +template <> +struct hip_type_impl { + using type = __half; +}; + template struct hip_type_impl> { - using type = thrust::complex; + using type = thrust::complex::type>; }; template <> @@ -217,6 +232,11 @@ struct hip_type_impl { using type = thrust::complex; }; +template <> +struct hip_type_impl<__half2> { + using type = thrust::complex<__half>; +}; + template struct hip_struct_member_type_impl { using type = T; @@ -224,7 +244,12 @@ struct hip_struct_member_type_impl { template struct hip_struct_member_type_impl> { - using type = fake_complex; + using type = fake_complex::type>; +}; + +template <> +struct hip_struct_member_type_impl { + using type = __half; }; template