diff --git a/include/FastFFT.cuh b/include/FastFFT.cuh index 27b9f71..4804d93 100644 --- a/include/FastFFT.cuh +++ b/include/FastFFT.cuh @@ -224,10 +224,40 @@ __global__ void clip_into_real_kernel(InputType* real_values_gpu, int3 wanted_coordinate_of_box_center, OutputBaseType wanted_padding_value); +template +__device__ __forceinline__ T load_with_hint(const T* ptr, const int idx) { + static_assert(hint_type == 0 || hint_type == 1 || hint_type == 2 || hint_type == 3, "invalid mem op hint type"); + // 0 is default for store or load + // Cache at all levels, likely to be accessed again. https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators + // Cache write-back all coherent levels. + if constexpr ( hint_type == 0 ) + return __ldca(&ptr[idx]); + else if constexpr ( hint_type == 1 ) + return __ldcg(&ptr[idx]); + else if constexpr ( hint_type == 2 ) + return __ldcs(&ptr[idx]); + else + return ptr[idx]; // when we are doing a store, we still want to convert type but there is no need for a load +}; + +template +__device__ __forceinline__ void store_with_hint(T* ptr, const int idx, T val_to_store) { + static_assert(hint_type == 0 || hint_type == 1 || hint_type == 2, "invalid mem op hint type"); + // 0 is default for store or load + // Cache at all levels, likely to be accessed again. https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators + // Cache write-back all coherent levels. + if constexpr ( hint_type == 0 ) + return __stwb(&ptr[idx]); + else if constexpr ( hint_type == 1 ) + return __stcg(&ptr[idx]); + else + return __stcs(&ptr[idx]); +} + // TODO: This would be much cleaner if we could first go from complex_compute_t -> float 2 then do conversions // I think since this would be a compile time decision, it would be fine, but it would be good to confirm. -template -inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, const int idx) { +template +inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, const unsigned int idx) { using complex_compute_t = typename FFT::value_type; using scalar_compute_t = typename complex_compute_t::value_type; @@ -245,38 +275,41 @@ inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, c if constexpr ( std::is_same_v || std::is_same_v ) { // In this case we assume we have a real valued result, packed into the first half of the complex array // TODO: think about cases where we may hit this block unintentionally and how to catch this - return std::move(reinterpret_cast(ptr)[idx]); + return std::move(load_with_hint(reinterpret_cast(ptr), idx)); } else if constexpr ( std::is_same_v ) { // In this case we assume we have a real valued result, packed into the first half of the complex array // TODO: think about cases where we may hit this block unintentionally and how to catch this - return std::move(__float2half_rn(reinterpret_cast(ptr)[idx])); + return std::move(__float2half_rn(load_with_hint(reinterpret_cast(ptr), idx))); } else if constexpr ( std::is_same_v ) { // Note: we will eventually need a similar hase for __nv_bfloat16 // I think I may need to strip the const news for this to work if constexpr ( std::is_same_v ) { - return std::move(__floats2half2_rn(ptr[idx].real( ), 0.f)); + return std::move(__floats2half2_rn(load_with_hint(ptr, idx).real( ), 0.f)); } else { - return std::move(__floats2half2_rn(static_cast(ptr)[idx], 0.f)); + return std::move(__floats2half2_rn(load_with_hint(static_cast(ptr), idx), 0.f)); } } else if constexpr ( std::is_same_v, complex_compute_t> && std::is_same_v, complex_compute_t> ) { // return std::move(static_cast(ptr)[idx]); - return std::move(ptr[idx]); + float2 t = load_with_hint(reinterpret_cast(ptr), idx); // FIXME + return std::move(complex_compute_t{t.x, t.y}); } else if constexpr ( std::is_same_v, complex_compute_t> && std::is_same_v, float2> ) { // return std::move(static_cast(ptr)[idx]); - return std::move(SetTo_t{ptr[idx].real( ), ptr[idx].imag( )}); + // return std::move(SetTo_t{load_with_hint(reinterpret_cast(ptr), idx).real( ), load_with_hint(reinterpret_cast(ptr), idx).imag( )}); + return std::move(SetTo_t{load_with_hint(reinterpret_cast(ptr), idx).x, load_with_hint(reinterpret_cast(ptr), idx).y}); } + else if constexpr ( std::is_same_v, float2> && std::is_same_v, complex_compute_t> ) { // return std::move(static_cast(ptr)[idx]); - return std::move(SetTo_t{ptr[idx].x, ptr[idx].y}); + return std::move(SetTo_t{load_with_hint(ptr, idx).x, load_with_hint(ptr, idx).y}); } else if constexpr ( std::is_same_v, float2> && std::is_same_v, float2> ) { // return std::move(static_cast(ptr)[idx]); - return std::move(ptr[idx]); + return std::move(load_with_hint(ptr, idx)); } else { static_no_match( ); @@ -286,20 +319,21 @@ inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, c if constexpr ( std::is_same_v || std::is_same_v ) { // In this case we assume we have a real valued result, packed into the first half of the complex array // TODO: think about cases where we may hit this block unintentionally and how to catch this - return std::move(static_cast(ptr)[idx]); + return std::move(load_with_hint(static_cast(ptr), idx)); } else if constexpr ( std::is_same_v ) { // In this case we assume we have a real valued result, packed into the first half of the complex array // TODO: think about cases where we may hit this block unintentionally and how to catch this - return std::move(__float2half_rn(static_cast(ptr)[idx])); + return std::move(__float2half_rn(load_with_hint(static_cast(ptr), idx))); } else if constexpr ( std::is_same_v ) { // Here we assume we are reading a real value and placeing it in a complex array. Could this go sideways? - return std::move(__floats2half2_rn(static_cast(ptr)[idx], 0.f)); + return std::move(__floats2half2_rn(load_with_hint(static_cast(ptr), idx), 0.f)); } else if constexpr ( std::is_same_v, complex_compute_t> || std::is_same_v, float2> ) { // Here we assume we are reading a real value and placeing it in a complex array. Could this go sideways? - return std::move(SetTo_t{static_cast(ptr)[idx], 0.f}); + + return std::move(SetTo_t{load_with_hint(static_cast(ptr), idx), 0.f}); } else { static_no_match( ); @@ -307,21 +341,21 @@ inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, c } else if constexpr ( std::is_same_v, __half> ) { if constexpr ( std::is_same_v || std::is_same_v ) { - return std::move(__half2float(ptr[idx])); + return std::move(__half2float(load_with_hint(ptr, idx))); } else if constexpr ( std::is_same_v ) { // In this case we assume we have a real valued result, packed into the first half of the complex array // TODO: think about cases where we may hit this block unintentionally and how to catch this - return std::move(static_cast(ptr)[idx]); + return std::move(load_with_hint(static_cast(ptr), idx)); } else if constexpr ( std::is_same_v ) { // Here we assume we are reading a real value and placeing it in a complex array. Could this go sideways? // FIXME: For some reason CUDART_ZERO_FP16 is not defined even with cuda_fp16.h included - return std::move(__halves2half2(static_cast(ptr)[idx], __ushort_as_half((unsigned short)0x0000U))); + return std::move(__halves2half2(load_with_hint(static_cast(ptr), idx), __ushort_as_half((unsigned short)0x0000U))); } else if constexpr ( std::is_same_v, complex_compute_t> || std::is_same_v, float2> ) { // Here we assume we are reading a real value and placeing it in a complex array. Could this go sideways? - return std::move(SetTo_t{__half2float(static_cast(ptr)[idx]), 0.f}); + return std::move(SetTo_t{__half2float(load_with_hint(static_cast(ptr), idx)), 0.f}); } else { static_no_match( ); @@ -331,16 +365,16 @@ inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, c if constexpr ( std::is_same_v || std::is_same_v || std::is_same_v ) { // In this case we assume we have a real valued result, packed into the first half of the complex array // TODO: think about cases where we may hit this block unintentionally and how to catch this - return std::move(reinterpret_cast(ptr)[idx]); + return std::move(load_with_hint(reinterpret_cast(ptr), idx)); } else if constexpr ( std::is_same_v ) { // Here we assume we are reading a real value and placeing it in a complex array. Could this go sideways? // FIXME: For some reason CUDART_ZERO_FP16 is not defined even with cuda_fp16.h included - return std::move(static_cast(ptr)[idx]); + return std::move(load_with_hint(static_cast(ptr), idx)); } else if constexpr ( std::is_same_v, complex_compute_t> || std::is_same_v, float2> ) { // Here we assume we are reading a real value and placeing it in a complex array. Could this go sideways? - return std::move(SetTo_t{__low2float(static_cast(ptr)[idx]), __high2float(static_cast(ptr)[idx])}); + return std::move(SetTo_t{__low2float(load_with_hint(static_cast(ptr), idx)), __high2float(load_with_hint(static_cast(ptr), idx))}); } else { static_no_match( ); @@ -1056,7 +1090,7 @@ struct io { for ( unsigned int i = 0; i < FFT::input_ept; i++ ) { if ( x_prime < SignalLength ) - smem_buffer[smem_idx] = convert_if_needed(input, x_prime * pixel_pitch + fft_idx + blockIdx.y * n_coalesced_ffts); + smem_buffer[smem_idx] = convert_if_needed(input, x_prime * pixel_pitch + fft_idx + blockIdx.y * n_coalesced_ffts); __syncthreads( ); if ( index < SignalLength ) @@ -1094,7 +1128,7 @@ struct io { unsigned int index = threadIdx.x; for ( unsigned int i = 0; i < FFT::elements_per_thread; i++ ) { - thread_data[i] = convert_if_needed(input, index); + thread_data[i] = convert_if_needed(input, index); index += FFT::stride; } } @@ -1112,7 +1146,7 @@ struct io { float2 temp; for ( unsigned int i = 0; i < FFT::elements_per_thread; i++ ) { if ( index < last_index_to_load ) { - temp = pre_op_functor(convert_if_needed(input, index)); + temp = pre_op_functor(convert_if_needed(input, index)); thread_data[i] = convert_if_needed(&temp, 0); } else {