Skip to content

Commit

Permalink
added functionality to use load/store intrinsics via template params …
Browse files Browse the repository at this point in the history
…to try to influence caching behavior.
  • Loading branch information
bHimes committed Dec 28, 2024
1 parent f39f2d7 commit 2fd9d99
Showing 1 changed file with 58 additions and 24 deletions.
82 changes: 58 additions & 24 deletions include/FastFFT.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,40 @@ __global__ void clip_into_real_kernel(InputType* real_values_gpu,
int3 wanted_coordinate_of_box_center,
OutputBaseType wanted_padding_value);

template <unsigned int hint_type, typename T>
__device__ __forceinline__ T load_with_hint(const T* ptr, const int idx) {
static_assert(hint_type == 0 || hint_type == 1 || hint_type == 2 || hint_type == 3, "invalid mem op hint type");
// 0 is default for store or load
// Cache at all levels, likely to be accessed again. https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators
// Cache write-back all coherent levels.
if constexpr ( hint_type == 0 )
return __ldca(&ptr[idx]);
else if constexpr ( hint_type == 1 )
return __ldcg(&ptr[idx]);
else if constexpr ( hint_type == 2 )
return __ldcs(&ptr[idx]);
else
return ptr[idx]; // when we are doing a store, we still want to convert type but there is no need for a load
};

template <unsigned int hint_type, typename T>
__device__ __forceinline__ void store_with_hint(T* ptr, const int idx, T val_to_store) {
static_assert(hint_type == 0 || hint_type == 1 || hint_type == 2, "invalid mem op hint type");
// 0 is default for store or load
// Cache at all levels, likely to be accessed again. https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators
// Cache write-back all coherent levels.
if constexpr ( hint_type == 0 )
return __stwb(&ptr[idx]);
else if constexpr ( hint_type == 1 )
return __stcg(&ptr[idx]);
else
return __stcs(&ptr[idx]);
}

// TODO: This would be much cleaner if we could first go from complex_compute_t -> float 2 then do conversions
// I think since this would be a compile time decision, it would be fine, but it would be good to confirm.
template <class FFT, typename SetTo_t, typename GetFrom_t>
inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, const int idx) {
template <class FFT, typename SetTo_t, typename GetFrom_t, unsigned int load_hint = 3>
inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, const unsigned int idx) {
using complex_compute_t = typename FFT::value_type;
using scalar_compute_t = typename complex_compute_t::value_type;

Expand All @@ -245,38 +275,41 @@ inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, c
if constexpr ( std::is_same_v<SetTo_t, scalar_compute_t> || std::is_same_v<SetTo_t, float> ) {
// In this case we assume we have a real valued result, packed into the first half of the complex array
// TODO: think about cases where we may hit this block unintentionally and how to catch this
return std::move(reinterpret_cast<const SetTo_t*>(ptr)[idx]);
return std::move(load_with_hint<load_hint>(reinterpret_cast<const SetTo_t*>(ptr), idx));
}
else if constexpr ( std::is_same_v<SetTo_t, __half> ) {
// In this case we assume we have a real valued result, packed into the first half of the complex array
// TODO: think about cases where we may hit this block unintentionally and how to catch this
return std::move(__float2half_rn(reinterpret_cast<const float*>(ptr)[idx]));
return std::move(__float2half_rn(load_with_hint<load_hint>(reinterpret_cast<const float*>(ptr), idx)));
}
else if constexpr ( std::is_same_v<SetTo_t, __half2> ) {
// Note: we will eventually need a similar hase for __nv_bfloat16
// I think I may need to strip the const news for this to work
if constexpr ( std::is_same_v<GetFrom_t, complex_compute_t> ) {
return std::move(__floats2half2_rn(ptr[idx].real( ), 0.f));
return std::move(__floats2half2_rn(load_with_hint<load_hint>(ptr, idx).real( ), 0.f));
}
else {
return std::move(__floats2half2_rn(static_cast<const float*>(ptr)[idx], 0.f));
return std::move(__floats2half2_rn(load_with_hint<load_hint>(static_cast<const float*>(ptr), idx), 0.f));
}
}
else if constexpr ( std::is_same_v<std::decay_t<GetFrom_t>, complex_compute_t> && std::is_same_v<std::decay_t<SetTo_t>, complex_compute_t> ) {
// return std::move(static_cast<const SetTo_t*>(ptr)[idx]);
return std::move(ptr[idx]);
float2 t = load_with_hint<load_hint>(reinterpret_cast<const float2*>(ptr), idx); // FIXME
return std::move(complex_compute_t{t.x, t.y});
}
else if constexpr ( std::is_same_v<std::decay_t<GetFrom_t>, complex_compute_t> && std::is_same_v<std::decay_t<SetTo_t>, float2> ) {
// return std::move(static_cast<const SetTo_t*>(ptr)[idx]);
return std::move(SetTo_t{ptr[idx].real( ), ptr[idx].imag( )});
// return std::move(SetTo_t{load_with_hint<load_hint>(reinterpret_cast<const float2*>(ptr), idx).real( ), load_with_hint<load_hint>(reinterpret_cast<const float2*>(ptr), idx).imag( )});
return std::move(SetTo_t{load_with_hint<load_hint>(reinterpret_cast<const float2*>(ptr), idx).x, load_with_hint<load_hint>(reinterpret_cast<const float2*>(ptr), idx).y});
}

else if constexpr ( std::is_same_v<std::decay_t<GetFrom_t>, float2> && std::is_same_v<std::decay_t<SetTo_t>, complex_compute_t> ) {
// return std::move(static_cast<const SetTo_t*>(ptr)[idx]);
return std::move(SetTo_t{ptr[idx].x, ptr[idx].y});
return std::move(SetTo_t{load_with_hint<load_hint>(ptr, idx).x, load_with_hint<load_hint>(ptr, idx).y});
}
else if constexpr ( std::is_same_v<std::decay_t<GetFrom_t>, float2> && std::is_same_v<std::decay_t<SetTo_t>, float2> ) {
// return std::move(static_cast<const SetTo_t*>(ptr)[idx]);
return std::move(ptr[idx]);
return std::move(load_with_hint<load_hint>(ptr, idx));
}
else {
static_no_match( );
Expand All @@ -286,42 +319,43 @@ inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, c
if constexpr ( std::is_same_v<SetTo_t, scalar_compute_t> || std::is_same_v<SetTo_t, float> ) {
// In this case we assume we have a real valued result, packed into the first half of the complex array
// TODO: think about cases where we may hit this block unintentionally and how to catch this
return std::move(static_cast<const SetTo_t*>(ptr)[idx]);
return std::move(load_with_hint<load_hint>(static_cast<const SetTo_t*>(ptr), idx));
}
else if constexpr ( std::is_same_v<SetTo_t, __half> ) {
// In this case we assume we have a real valued result, packed into the first half of the complex array
// TODO: think about cases where we may hit this block unintentionally and how to catch this
return std::move(__float2half_rn(static_cast<const float*>(ptr)[idx]));
return std::move(__float2half_rn(load_with_hint<load_hint>(static_cast<const float*>(ptr), idx)));
}
else if constexpr ( std::is_same_v<SetTo_t, __half2> ) {
// Here we assume we are reading a real value and placeing it in a complex array. Could this go sideways?
return std::move(__floats2half2_rn(static_cast<const float*>(ptr)[idx], 0.f));
return std::move(__floats2half2_rn(load_with_hint<load_hint>(static_cast<const float*>(ptr), idx), 0.f));
}
else if constexpr ( std::is_same_v<std::decay_t<SetTo_t>, complex_compute_t> || std::is_same_v<std::decay_t<SetTo_t>, float2> ) {
// Here we assume we are reading a real value and placeing it in a complex array. Could this go sideways?
return std::move(SetTo_t{static_cast<const float*>(ptr)[idx], 0.f});

return std::move(SetTo_t{load_with_hint<load_hint>(static_cast<const float*>(ptr), idx), 0.f});
}
else {
static_no_match( );
}
}
else if constexpr ( std::is_same_v<std::decay_t<GetFrom_t>, __half> ) {
if constexpr ( std::is_same_v<SetTo_t, scalar_compute_t> || std::is_same_v<SetTo_t, float> ) {
return std::move(__half2float(ptr[idx]));
return std::move(__half2float(load_with_hint<load_hint>(ptr, idx)));
}
else if constexpr ( std::is_same_v<SetTo_t, __half> ) {
// In this case we assume we have a real valued result, packed into the first half of the complex array
// TODO: think about cases where we may hit this block unintentionally and how to catch this
return std::move(static_cast<const SetTo_t*>(ptr)[idx]);
return std::move(load_with_hint<load_hint>(static_cast<const SetTo_t*>(ptr), idx));
}
else if constexpr ( std::is_same_v<SetTo_t, __half2> ) {
// Here we assume we are reading a real value and placeing it in a complex array. Could this go sideways?
// FIXME: For some reason CUDART_ZERO_FP16 is not defined even with cuda_fp16.h included
return std::move(__halves2half2(static_cast<const __half*>(ptr)[idx], __ushort_as_half((unsigned short)0x0000U)));
return std::move(__halves2half2(load_with_hint<load_hint>(static_cast<const __half*>(ptr), idx), __ushort_as_half((unsigned short)0x0000U)));
}
else if constexpr ( std::is_same_v<std::decay_t<SetTo_t>, complex_compute_t> || std::is_same_v<std::decay_t<SetTo_t>, float2> ) {
// Here we assume we are reading a real value and placeing it in a complex array. Could this go sideways?
return std::move(SetTo_t{__half2float(static_cast<const __half*>(ptr)[idx]), 0.f});
return std::move(SetTo_t{__half2float(load_with_hint<load_hint>(static_cast<const __half*>(ptr), idx)), 0.f});
}
else {
static_no_match( );
Expand All @@ -331,16 +365,16 @@ inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, c
if constexpr ( std::is_same_v<SetTo_t, scalar_compute_t> || std::is_same_v<SetTo_t, float> || std::is_same_v<SetTo_t, __half> ) {
// In this case we assume we have a real valued result, packed into the first half of the complex array
// TODO: think about cases where we may hit this block unintentionally and how to catch this
return std::move(reinterpret_cast<const SetTo_t*>(ptr)[idx]);
return std::move(load_with_hint<load_hint>(reinterpret_cast<const SetTo_t*>(ptr), idx));
}
else if constexpr ( std::is_same_v<SetTo_t, __half2> ) {
// Here we assume we are reading a real value and placeing it in a complex array. Could this go sideways?
// FIXME: For some reason CUDART_ZERO_FP16 is not defined even with cuda_fp16.h included
return std::move(static_cast<const SetTo_t*>(ptr)[idx]);
return std::move(load_with_hint<load_hint>(static_cast<const SetTo_t*>(ptr), idx));
}
else if constexpr ( std::is_same_v<std::decay_t<SetTo_t>, complex_compute_t> || std::is_same_v<std::decay_t<SetTo_t>, float2> ) {
// Here we assume we are reading a real value and placeing it in a complex array. Could this go sideways?
return std::move(SetTo_t{__low2float(static_cast<const __half2*>(ptr)[idx]), __high2float(static_cast<const __half2*>(ptr)[idx])});
return std::move(SetTo_t{__low2float(load_with_hint<load_hint>(static_cast<const __half2*>(ptr), idx)), __high2float(load_with_hint<load_hint>(static_cast<const __half2*>(ptr), idx))});
}
else {
static_no_match( );
Expand Down Expand Up @@ -1056,7 +1090,7 @@ struct io {
for ( unsigned int i = 0; i < FFT::input_ept; i++ ) {

if ( x_prime < SignalLength )
smem_buffer[smem_idx] = convert_if_needed<FFT, complex_compute_t>(input, x_prime * pixel_pitch + fft_idx + blockIdx.y * n_coalesced_ffts);
smem_buffer[smem_idx] = convert_if_needed<FFT, complex_compute_t, data_io_t, 1>(input, x_prime * pixel_pitch + fft_idx + blockIdx.y * n_coalesced_ffts);
__syncthreads( );

if ( index < SignalLength )
Expand Down Expand Up @@ -1094,7 +1128,7 @@ struct io {

unsigned int index = threadIdx.x;
for ( unsigned int i = 0; i < FFT::elements_per_thread; i++ ) {
thread_data[i] = convert_if_needed<FFT, complex_compute_t>(input, index);
thread_data[i] = convert_if_needed<FFT, complex_compute_t, data_io_t, 2>(input, index);
index += FFT::stride;
}
}
Expand All @@ -1112,7 +1146,7 @@ struct io {
float2 temp;
for ( unsigned int i = 0; i < FFT::elements_per_thread; i++ ) {
if ( index < last_index_to_load ) {
temp = pre_op_functor(convert_if_needed<FFT, float2>(input, index));
temp = pre_op_functor(convert_if_needed<FFT, float2, data_io_t, 2>(input, index));
thread_data[i] = convert_if_needed<FFT, complex_compute_t>(&temp, 0);
}
else {
Expand Down

0 comments on commit 2fd9d99

Please sign in to comment.