Skip to content

Commit

Permalink
type map
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Oct 22, 2024
1 parent 6c1ae5b commit 265cb47
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 12 deletions.
15 changes: 14 additions & 1 deletion accessor/cuda_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@
#include "utils.hpp"


struct __half;


namespace gko {


class half;


namespace acc {
namespace detail {

Expand All @@ -27,6 +35,11 @@ struct cuda_type {
using type = T;
};

template <>
struct cuda_type<gko::half> {
using type = __half;
};

// Unpack cv and reference / pointer qualifiers
template <typename T>
struct cuda_type<const T> {
Expand Down Expand Up @@ -57,7 +70,7 @@ struct cuda_type<T&&> {
// Transform std::complex to thrust::complex
template <typename T>
struct cuda_type<std::complex<T>> {
using type = thrust::complex<T>;
using type = thrust::complex<typename cuda_type<T>::type>;
};


Expand Down
14 changes: 13 additions & 1 deletion accessor/hip_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@
#include "utils.hpp"


struct __half;


namespace gko {


class half;


namespace acc {
namespace detail {

Expand Down Expand Up @@ -53,11 +61,15 @@ struct hip_type<T&&> {
using type = typename hip_type<T>::type&&;
};

template <>
struct hip_type<gko::half> {
using type = __half;
};

// Transform std::complex to thrust::complex
template <typename T>
struct hip_type<std::complex<T>> {
using type = thrust::complex<T>;
using type = thrust::complex<typename hip_type<T>::type>;
};


Expand Down
36 changes: 30 additions & 6 deletions cuda/base/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,17 @@
#include <cusparse.h>
#include <thrust/complex.h>

#include <ginkgo/core/base/half.hpp>
#include <ginkgo/core/base/matrix_data.hpp>
#include <ginkgo/core/base/types.hpp>


namespace gko {


namespace kernels {
namespace cuda {


namespace detail {


/**
* @internal
*
Expand Down Expand Up @@ -124,6 +121,17 @@ struct culibs_type_impl<std::complex<double>> {
using type = cuDoubleComplex;
};


template <>
struct culibs_type_impl<half> {
using type = __half;
};

template <>
struct culibs_type_impl<std::complex<half>> {
using type = __half2;
};

template <typename T>
struct culibs_type_impl<thrust::complex<T>> {
using type = typename culibs_type_impl<std::complex<T>>::type;
Expand Down Expand Up @@ -154,9 +162,14 @@ struct cuda_type_impl<volatile T> {
using type = volatile typename cuda_type_impl<T>::type;
};

template <>
struct cuda_type_impl<half> {
using type = __half;
};

template <typename T>
struct cuda_type_impl<std::complex<T>> {
using type = thrust::complex<T>;
using type = thrust::complex<typename cuda_type_impl<T>::type>;
};

template <>
Expand All @@ -169,14 +182,24 @@ struct cuda_type_impl<cuComplex> {
using type = thrust::complex<float>;
};

template <>
struct cuda_type_impl<__half2> {
using type = thrust::complex<__half>;
};

template <typename T>
struct cuda_struct_member_type_impl {
using type = T;
};

template <typename T>
struct cuda_struct_member_type_impl<std::complex<T>> {
using type = fake_complex<T>;
using type = fake_complex<typename cuda_struct_member_type_impl<T>::type>;
};

template <>
struct cuda_struct_member_type_impl<gko::half> {
using type = __half;
};

template <typename ValueType, typename IndexType>
Expand All @@ -200,6 +223,7 @@ GKO_CUDA_DATA_TYPE(float, CUDA_R_32F);
GKO_CUDA_DATA_TYPE(double, CUDA_R_64F);
GKO_CUDA_DATA_TYPE(std::complex<float>, CUDA_C_32F);
GKO_CUDA_DATA_TYPE(std::complex<double>, CUDA_C_64F);
GKO_CUDA_DATA_TYPE(std::complex<float16>, CUDA_C_16F);
GKO_CUDA_DATA_TYPE(int32, CUDA_R_32I);
GKO_CUDA_DATA_TYPE(int8, CUDA_R_8I);

Expand Down
33 changes: 29 additions & 4 deletions hip/base/types.hip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@
#endif
#include <thrust/complex.h>

#include <ginkgo/core/base/half.hpp>
#include <ginkgo/core/base/matrix_data.hpp>

#include "common/cuda_hip/base/runtime.hpp"


namespace gko {


namespace kernels {
namespace hip {
namespace detail {
Expand Down Expand Up @@ -130,6 +129,17 @@ struct hiplibs_type_impl<std::complex<double>> {
using type = hipDoubleComplex;
};

template <>
struct hiplibs_type_impl<half> {
using type = __half;
};

template <>
struct hiplibs_type_impl<std::complex<half>> {
using type = __half2;
};


template <typename T>
struct hiplibs_type_impl<thrust::complex<T>> {
using type = typename hiplibs_type_impl<std::complex<T>>::type;
Expand Down Expand Up @@ -202,9 +212,14 @@ struct hip_type_impl<volatile T> {
using type = volatile typename hip_type_impl<T>::type;
};

template <>
struct hip_type_impl<gko::half> {
using type = __half;
};

template <typename T>
struct hip_type_impl<std::complex<T>> {
using type = thrust::complex<T>;
using type = thrust::complex<typename hip_type_impl<T>::type>;
};

template <>
Expand All @@ -217,14 +232,24 @@ struct hip_type_impl<hipComplex> {
using type = thrust::complex<float>;
};

template <>
struct hip_type_impl<__half2> {
using type = thrust::complex<__half>;
};

template <typename T>
struct hip_struct_member_type_impl {
using type = T;
};

template <typename T>
struct hip_struct_member_type_impl<std::complex<T>> {
using type = fake_complex<T>;
using type = fake_complex<typename hip_struct_member_type_impl<T>::type>;
};

template <>
struct hip_struct_member_type_impl<gko::half> {
using type = __half;
};

template <typename ValueType, typename IndexType>
Expand Down

0 comments on commit 265cb47

Please sign in to comment.