Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Support __half on pre-sm53
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Jun 25, 2022
1 parent d3e9c5f commit acc4b74
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 11 deletions.
43 changes: 36 additions & 7 deletions cub/device/dispatch/dispatch_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,8 @@ private:
/**
 * Computes the scale for __half levels: the distance between the upper and
 * lower level divided by the number of bins.
 *
 * The subtraction and division are carried out in float. The values are
 * moved between __half and float with the explicit __half2float /
 * __float2half conversion functions rather than through static_cast, so the
 * function does not rely on __half's conversion operators.
 *
 * @param lower_level  Lower level of the range
 * @param upper_level  Upper level of the range
 * @param bins         Divisor applied to the range (not checked for zero)
 * @return             (upper_level - lower_level) / bins, converted to __half
 */
CUB_RUNTIME_FUNCTION
static __half ComputeScale(__half lower_level, __half upper_level, int bins)
{
  const float range = __half2float(upper_level) - __half2float(lower_level);

  // float / int: the divisor is converted to float before the division.
  return __float2half(range / bins);
}
#endif

Expand Down Expand Up @@ -303,15 +302,45 @@ public:
this->scale = double(1.0) / scale_;
}

/**
 * Writes the bin index of `sample` into `bin`.
 *
 * The index is the sample's offset from `min` divided by `scale`, converted
 * to int with static_cast. `bin` is left untouched unless `valid` is set and
 * the sample satisfies min <= sample < max.
 *
 * @tparam T      Type in which the comparisons and arithmetic are performed
 * @param sample  Value to classify
 * @param min     Inclusive lower bound
 * @param max     Exclusive upper bound
 * @param scale   Divisor applied to (sample - min)
 * @param bin     [out] Receives the bin index when the sample is accepted
 * @param valid   When false, nothing is computed or written
 */
template <typename T>
static __device__ __forceinline__ void
BinSelectImpl(T sample, T min, T max, T scale, int &bin, bool valid)
{
  // Reject early. The comparisons are negated as written (not flipped to
  // `<` / `>=`) so that unordered values are rejected exactly as before.
  if (!valid || !(sample >= min) || !(sample < max))
  {
    return;
  }

  const T offset = sample - min;
  bin            = static_cast<int>(offset / scale);
}

/**
 * Method for converting samples to bin-ids.
 *
 * The sample is converted to LevelT and forwarded, together with the
 * enclosing scope's `min`, `max` and `scale`, to BinSelectImpl, which writes
 * `bin` only when `valid` is set and the converted sample lies in
 * [min, max).
 *
 * LOAD_MODIFIER is part of the signature but is not used in this body.
 *
 * NOTE(review): this wrapper is __host__ __device__ while BinSelectImpl is
 * declared __device__ only; confirm it is never instantiated for host code.
 */
template <CacheLoadModifier LOAD_MODIFIER, typename _SampleT>
__host__ __device__ __forceinline__ void BinSelect(_SampleT sample,
                                                   int &bin,
                                                   bool valid)
{
  BinSelectImpl(static_cast<LevelT>(sample), min, max, scale, bin, valid);
}
#if defined(__CUDA_FP16_TYPES_EXIST__)
/**
 * Method for converting __half samples to bin-ids.
 *
 * When the target provides SM 5.3, the sample and the enclosing scope's
 * `min`, `max` and `scale` are passed to BinSelectImpl as __half. Otherwise
 * all four values are converted with __half2float and BinSelectImpl is
 * invoked on float.
 *
 * LOAD_MODIFIER is part of the signature but is not used in this body.
 */
template <CacheLoadModifier LOAD_MODIFIER>
__device__ __forceinline__ void BinSelect(__half sample, int &bin, bool valid)
{
  NV_IF_TARGET(
    NV_PROVIDES_SM_53,
    (BinSelectImpl<__half>(sample, min, max, scale, bin, valid);),
    (const float sample_f = __half2float(sample);
     const float lower_f  = __half2float(min);
     const float upper_f  = __half2float(max);
     const float width_f  = __half2float(scale);
     BinSelectImpl<float>(sample_f, lower_f, upper_f, width_f, bin, valid);));
}
#endif

// Method for converting samples to bin-ids (float specialization)
template <CacheLoadModifier LOAD_MODIFIER>
Expand Down
33 changes: 33 additions & 0 deletions cub/thread/thread_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
#include <cub/util_type.cuh>
#include <cub/config.cuh>

#include <nv/target>

CUB_NAMESPACE_BEGIN


Expand Down Expand Up @@ -146,7 +148,38 @@ __device__ __forceinline__ OffsetT UpperBound(
}


#if defined(__CUDA_FP16_TYPES_EXIST__)
/**
 * Binary search over `input` with a __half search key.
 *
 * Each iteration probes the middle of the remaining range. If the key
 * compares less than the probed element the range shrinks to the part before
 * the probe; otherwise it continues just past the probe. The returned offset
 * is where the range becomes empty.
 * (Presumably `input` must be ordered for the result to be meaningful -
 * TODO confirm against callers.)
 *
 * When the target provides SM 5.3 the comparison is done directly on the
 * key and the element; otherwise both are converted with __half2float and
 * compared as float.
 */
template <
    typename InputIteratorT,
    typename OffsetT>
__device__ __forceinline__ OffsetT UpperBound(
    InputIteratorT      input,              ///< [in] Input sequence
    OffsetT             num_items,          ///< [in] Input sequence length
    __half              val)                ///< [in] Search key
{
    OffsetT first = 0;

    for (OffsetT remaining = num_items; remaining > 0;)
    {
        const OffsetT step = remaining >> 1;

        bool key_is_less;
        NV_IF_TARGET(NV_PROVIDES_SM_53,
                     (key_is_less = val < input[first + step];),
                     (key_is_less = __half2float(val) <
                                    __half2float(input[first + step]);));

        if (!key_is_less)
        {
            // Key is not less than the probe: continue right of it.
            first     = first + (step + 1);
            remaining = remaining - (step + 1);
        }
        else
        {
            // Key is less than the probe: keep only the elements before it.
            remaining = step;
        }
    }

    return first;
}
#endif

CUB_NAMESPACE_END
8 changes: 4 additions & 4 deletions test/test_device_histogram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ struct Dispatch<NUM_ACTIVE_CHANNELS, NUM_CHANNELS, CUB>
return error;
}

#if defined(TEST_HALF_T)
#if TEST_HALF_T
/**
* Dispatch to CUB multi histogram-range entrypoint
*/
Expand Down Expand Up @@ -212,7 +212,7 @@ struct Dispatch<NUM_ACTIVE_CHANNELS, NUM_CHANNELS, CUB>
return error;
}

#if defined(TEST_HALF_T)
#if TEST_HALF_T
/**
* Dispatch to CUB multi histogram-even entrypoint
*/
Expand Down Expand Up @@ -304,7 +304,7 @@ struct Dispatch<1, 1, CUB>
return error;
}

#if defined(TEST_HALF_T)
#if TEST_HALF_T
template <typename CounterT, typename OffsetT>
//CUB_RUNTIME_FUNCTION __forceinline__
static cudaError_t Range(
Expand Down Expand Up @@ -388,7 +388,7 @@ struct Dispatch<1, 1, CUB>
return error;
}

#if defined(TEST_HALF_T)
#if TEST_HALF_T
template <typename CounterT, typename OffsetT>
//CUB_RUNTIME_FUNCTION __forceinline__
static cudaError_t Even(
Expand Down

0 comments on commit acc4b74

Please sign in to comment.