Commit

wip
bHimes committed Jan 9, 2025
1 parent 7c4908b commit 0e5210a
Showing 3 changed files with 76 additions and 105 deletions.
85 changes: 27 additions & 58 deletions include/FastFFT.cuh
@@ -138,7 +138,6 @@ __launch_bounds__(FFT::max_threads_per_block) __global__
const ComplexData_t* __restrict__ input_values,
ComplexData_t* __restrict__ output_values,
Offsets mem_offsets,
- int Q,
typename FFT::workspace_type workspace_fwd,
typename invFFT::workspace_type workspace_inv,
PreOpType pre_op_functor,
@@ -226,39 +225,9 @@ __global__ void clip_into_real_kernel(PositionSpaceType* real_values_gpu,
int3 wanted_coordinate_of_box_center,
OutputBaseType wanted_padding_value);

- template <unsigned int hint_type, typename T>
- __device__ __forceinline__ T load_with_hint(const T* ptr, const int idx) {
- static_assert(hint_type == 0 || hint_type == 1 || hint_type == 2 || hint_type == 3, "invalid mem op hint type");
- // 0 is default for store or load
- // Cache at all levels, likely to be accessed again. https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators
- // Cache write-back all coherent levels.
- if constexpr ( hint_type == 0 )
- return __ldca(&ptr[idx]);
- else if constexpr ( hint_type == 1 )
- return __ldcg(&ptr[idx]);
- else if constexpr ( hint_type == 2 )
- return __ldcs(&ptr[idx]);
- else
- return ptr[idx]; // when we are doing a store, we still want to convert type but there is no need for a load
- };
-
- template <unsigned int hint_type, typename T>
- __device__ __forceinline__ void store_with_hint(T* ptr, const int idx, T val_to_store) {
- static_assert(hint_type == 0 || hint_type == 1 || hint_type == 2, "invalid mem op hint type");
- // 0 is default for store or load
- // Cache at all levels, likely to be accessed again. https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators
- // Cache write-back all coherent levels.
- if constexpr ( hint_type == 0 )
- return __stwb(&ptr[idx]);
- else if constexpr ( hint_type == 1 )
- return __stcg(&ptr[idx]);
- else
- return __stcs(&ptr[idx]);
- }
-
// TODO: This would be much cleaner if we could first go from complex_compute_t -> float2 then do conversions
// I think since this would be a compile time decision, it would be fine, but it would be good to confirm.
- template <class FFT, typename SetTo_t, typename GetFrom_t, unsigned int load_hint = 3>
+ template <class FFT, typename SetTo_t, typename GetFrom_t>
inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, const unsigned int idx) {
using complex_compute_t = typename FFT::value_type;
using scalar_compute_t = typename complex_compute_t::value_type;
@@ -277,41 +246,41 @@ inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, c
if constexpr ( std::is_same_v<SetTo_t, scalar_compute_t> || std::is_same_v<SetTo_t, float> ) {
// In this case we assume we have a real valued result, packed into the first half of the complex array
// TODO: think about cases where we may hit this block unintentionally and how to catch this
- return std::move(load_with_hint<load_hint>(reinterpret_cast<const SetTo_t*>(ptr), idx));
+ return std::move(reinterpret_cast<const SetTo_t*>(ptr)[idx]);
}
else if constexpr ( std::is_same_v<SetTo_t, __half> ) {
// In this case we assume we have a real valued result, packed into the first half of the complex array
// TODO: think about cases where we may hit this block unintentionally and how to catch this
- return std::move(__float2half_rn(load_with_hint<load_hint>(reinterpret_cast<const float*>(ptr), idx)));
+ return std::move(__float2half_rn(reinterpret_cast<const float*>(ptr)[idx]));
}
else if constexpr ( std::is_same_v<SetTo_t, __half2> ) {
// Note: we will eventually need a similar case for __nv_bfloat16
// I think I may need to strip the const-ness for this to work
if constexpr ( std::is_same_v<GetFrom_t, complex_compute_t> ) {
- return std::move(__floats2half2_rn(load_with_hint<load_hint>(ptr, idx).real( ), 0.f));
+ return std::move(__floats2half2_rn((ptr[idx]).real( ), 0.f));
}
else {
- return std::move(__floats2half2_rn(load_with_hint<load_hint>(static_cast<const float*>(ptr), idx), 0.f));
+ return std::move(__floats2half2_rn(static_cast<const float*>(ptr)[idx], 0.f));
}
}
else if constexpr ( std::is_same_v<std::decay_t<GetFrom_t>, complex_compute_t> && std::is_same_v<std::decay_t<SetTo_t>, complex_compute_t> ) {
// return std::move(static_cast<const SetTo_t*>(ptr)[idx]);
- float2 t = load_with_hint<load_hint>(reinterpret_cast<const float2*>(ptr), idx); // FIXME
+ float2 t = (reinterpret_cast<const float2*>(ptr)[idx]); // FIXME
return std::move(complex_compute_t{t.x, t.y});
}
else if constexpr ( std::is_same_v<std::decay_t<GetFrom_t>, complex_compute_t> && std::is_same_v<std::decay_t<SetTo_t>, float2> ) {
// return std::move(static_cast<const SetTo_t*>(ptr)[idx]);
- // return std::move(SetTo_t{load_with_hint<load_hint>(reinterpret_cast<const float2*>(ptr), idx).real( ), load_with_hint<load_hint>(reinterpret_cast<const float2*>(ptr), idx).imag( )});
- return std::move(SetTo_t{load_with_hint<load_hint>(reinterpret_cast<const float2*>(ptr), idx).x, load_with_hint<load_hint>(reinterpret_cast<const float2*>(ptr), idx).y});
+ // return std::move(SetTo_t{(reinterpret_cast<const float2*>(ptr)[idx]).real( ), (reinterpret_cast<const float2*>(ptr)[idx]).imag( )});
+ return std::move(SetTo_t{(reinterpret_cast<const float2*>(ptr)[idx]).x, (reinterpret_cast<const float2*>(ptr)[idx]).y});
}

else if constexpr ( std::is_same_v<std::decay_t<GetFrom_t>, float2> && std::is_same_v<std::decay_t<SetTo_t>, complex_compute_t> ) {
// return std::move(static_cast<const SetTo_t*>(ptr)[idx]);
- return std::move(SetTo_t{load_with_hint<load_hint>(ptr, idx).x, load_with_hint<load_hint>(ptr, idx).y});
+ return std::move(SetTo_t{(ptr[idx]).x, (ptr[idx]).y});
}
else if constexpr ( std::is_same_v<std::decay_t<GetFrom_t>, float2> && std::is_same_v<std::decay_t<SetTo_t>, float2> ) {
// return std::move(static_cast<const SetTo_t*>(ptr)[idx]);
- return std::move(load_with_hint<load_hint>(ptr, idx));
+ return std::move((ptr[idx]));
}
else {
static_no_match( );
@@ -321,43 +290,43 @@ inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, c
if constexpr ( std::is_same_v<SetTo_t, scalar_compute_t> || std::is_same_v<SetTo_t, float> ) {
// In this case we assume we have a real valued result, packed into the first half of the complex array
// TODO: think about cases where we may hit this block unintentionally and how to catch this
- return std::move(load_with_hint<load_hint>(static_cast<const SetTo_t*>(ptr), idx));
+ return std::move((static_cast<const SetTo_t*>(ptr)[idx]));
}
else if constexpr ( std::is_same_v<SetTo_t, __half> ) {
// In this case we assume we have a real valued result, packed into the first half of the complex array
// TODO: think about cases where we may hit this block unintentionally and how to catch this
- return std::move(__float2half_rn(load_with_hint<load_hint>(static_cast<const float*>(ptr), idx)));
+ return std::move(__float2half_rn((static_cast<const float*>(ptr)[idx])));
}
else if constexpr ( std::is_same_v<SetTo_t, __half2> ) {
// Here we assume we are reading a real value and placing it in a complex array. Could this go sideways?
- return std::move(__floats2half2_rn(load_with_hint<load_hint>(static_cast<const float*>(ptr), idx), 0.f));
+ return std::move(__floats2half2_rn((static_cast<const float*>(ptr)[idx]), 0.f));
}
else if constexpr ( std::is_same_v<std::decay_t<SetTo_t>, complex_compute_t> || std::is_same_v<std::decay_t<SetTo_t>, float2> ) {
// Here we assume we are reading a real value and placing it in a complex array. Could this go sideways?

- return std::move(SetTo_t{load_with_hint<load_hint>(static_cast<const float*>(ptr), idx), 0.f});
+ return std::move(SetTo_t{(static_cast<const float*>(ptr)[idx]), 0.f});
}
else {
static_no_match( );
}
}
else if constexpr ( std::is_same_v<std::decay_t<GetFrom_t>, __half> ) {
if constexpr ( std::is_same_v<SetTo_t, scalar_compute_t> || std::is_same_v<SetTo_t, float> ) {
- return std::move(__half2float(load_with_hint<load_hint>(ptr, idx)));
+ return std::move(__half2float((ptr[idx])));
}
else if constexpr ( std::is_same_v<SetTo_t, __half> ) {
// In this case we assume we have a real valued result, packed into the first half of the complex array
// TODO: think about cases where we may hit this block unintentionally and how to catch this
- return std::move(load_with_hint<load_hint>(static_cast<const SetTo_t*>(ptr), idx));
+ return std::move((static_cast<const SetTo_t*>(ptr)[idx]));
}
else if constexpr ( std::is_same_v<SetTo_t, __half2> ) {
// Here we assume we are reading a real value and placing it in a complex array. Could this go sideways?
// FIXME: For some reason CUDART_ZERO_FP16 is not defined even with cuda_fp16.h included
- return std::move(__halves2half2(load_with_hint<load_hint>(static_cast<const __half*>(ptr), idx), __ushort_as_half((unsigned short)0x0000U)));
+ return std::move(__halves2half2((static_cast<const __half*>(ptr)[idx]), __ushort_as_half((unsigned short)0x0000U)));
}
else if constexpr ( std::is_same_v<std::decay_t<SetTo_t>, complex_compute_t> || std::is_same_v<std::decay_t<SetTo_t>, float2> ) {
// Here we assume we are reading a real value and placing it in a complex array. Could this go sideways?
- return std::move(SetTo_t{__half2float(load_with_hint<load_hint>(static_cast<const __half*>(ptr), idx)), 0.f});
+ return std::move(SetTo_t{__half2float((static_cast<const __half*>(ptr)[idx])), 0.f});
}
else {
static_no_match( );
@@ -367,16 +336,16 @@ inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, c
if constexpr ( std::is_same_v<SetTo_t, scalar_compute_t> || std::is_same_v<SetTo_t, float> || std::is_same_v<SetTo_t, __half> ) {
// In this case we assume we have a real valued result, packed into the first half of the complex array
// TODO: think about cases where we may hit this block unintentionally and how to catch this
- return std::move(load_with_hint<load_hint>(reinterpret_cast<const SetTo_t*>(ptr), idx));
+ return std::move((reinterpret_cast<const SetTo_t*>(ptr)[idx]));
}
else if constexpr ( std::is_same_v<SetTo_t, __half2> ) {
// Here we assume we are reading a real value and placing it in a complex array. Could this go sideways?
// FIXME: For some reason CUDART_ZERO_FP16 is not defined even with cuda_fp16.h included
- return std::move(load_with_hint<load_hint>(static_cast<const SetTo_t*>(ptr), idx));
+ return std::move((static_cast<const SetTo_t*>(ptr)[idx]));
}
else if constexpr ( std::is_same_v<std::decay_t<SetTo_t>, complex_compute_t> || std::is_same_v<std::decay_t<SetTo_t>, float2> ) {
// Here we assume we are reading a real value and placing it in a complex array. Could this go sideways?
- return std::move(SetTo_t{__low2float(load_with_hint<load_hint>(static_cast<const __half2*>(ptr), idx)), __high2float(load_with_hint<load_hint>(static_cast<const __half2*>(ptr), idx))});
+ return std::move(SetTo_t{__low2float((static_cast<const __half2*>(ptr)[idx])), __high2float((static_cast<const __half2*>(ptr)[idx]))});
}
else {
static_no_match( );
@@ -864,9 +833,9 @@ struct io {

// TODO: set user lambda to default = false, then get rid of other load_shared
template <typename ExternalImage_t, class FunctionType = std::nullptr_t>
- static inline __device__ void load_shared(const ExternalImage_t* __restrict__ image_to_search,
- complex_compute_t* __restrict__ thread_data,
- FunctionType intra_op_functor = nullptr) {
+ static inline __device__ void load_external_data(const ExternalImage_t* __restrict__ image_to_search,
+ complex_compute_t* __restrict__ thread_data,
+ FunctionType intra_op_functor = nullptr) {

unsigned int index = threadIdx.x;
if constexpr ( IS_IKF_t<FunctionType>( ) ) {
@@ -1085,7 +1054,7 @@ struct io {
for ( unsigned int i = 0; i < FFT::input_ept; i++ ) {

if ( x_prime < SignalLength )
- smem_buffer[smem_idx] = convert_if_needed<FFT, complex_compute_t, data_io_t, 1>(input, x_prime * pixel_pitch + fft_idx + blockIdx.y * n_coalesced_ffts);
+ smem_buffer[smem_idx] = convert_if_needed<FFT, complex_compute_t>(input, x_prime * pixel_pitch + fft_idx + blockIdx.y * n_coalesced_ffts);
__syncthreads( );

if ( index < SignalLength )
@@ -1123,7 +1092,7 @@ struct io {

unsigned int index = threadIdx.x;
for ( unsigned int i = 0; i < FFT::elements_per_thread; i++ ) {
- thread_data[i] = convert_if_needed<FFT, complex_compute_t, data_io_t, 2>(input, index);
+ thread_data[i] = convert_if_needed<FFT, complex_compute_t>(input, index);
index += FFT::stride;
}
}
@@ -1141,7 +1110,7 @@
float2 temp;
for ( unsigned int i = 0; i < FFT::elements_per_thread; i++ ) {
if ( index < last_index_to_load ) {
- temp = pre_op_functor(convert_if_needed<FFT, float2, data_io_t, 2>(input, index));
+ temp = pre_op_functor(convert_if_needed<FFT, float2>(input, index));
thread_data[i] = convert_if_needed<FFT, complex_compute_t>(&temp, 0);
}
else {
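In summary, the FastFFT.cuh hunks above drop the load_with_hint/store_with_hint wrappers and the load_hint template parameter on convert_if_needed, so every call site falls back to a plain indexed load. The removed store_with_hint also never forwarded val_to_store to the store intrinsics, so the store path was incomplete to begin with. For reference only, a minimal standalone sketch of how the underlying CUDA cache-operator intrinsics are normally used; this is not code from this repository, the kernel and parameter names are illustrative, and the intrinsics require sm_32 or newer:

#include <cuda_runtime.h>

// Streaming (evict-first) load/store, suited to data touched exactly once.
// __ldca caches at all levels and __ldcg caches only in L2; __stwb (write-back,
// the default policy) and __stcg are the store-side counterparts described at
// the PTX cache-operators page referenced in the removed comments.
__global__ void scale_streaming(const float* __restrict__ in,
                                float* __restrict__ out,
                                int n, float s) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if ( i < n ) {
        float v = __ldcs(&in[i]); // streaming load
        __stcs(&out[i], v * s);   // streaming store takes (pointer, value)
    }
}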
2 changes: 1 addition & 1 deletion include/FastFFT.h
@@ -18,7 +18,7 @@
// #define FastFFT_build_sizes 32, 64, 128, 256, 512, 1024, 2048, 4096

// #define FastFFT_build_sizes 16, 4, 32, 8, 64, 8, 128, 8, 256, 8, 512, 8, 1024, 8, 2048, 8, 4096, 16, 8192, 16
- #define FastFFT_build_sizes 64, 128
+ #define FastFFT_build_sizes 512, 4096

namespace FastFFT {

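FastFFT.h now builds for transform sizes 512 and 4096 instead of 64 and 128. This diff does not show how the comma-separated FastFFT_build_sizes list is consumed; as a hedged illustration only (the repository's actual mechanism may differ, and instantiate_for_size / instantiate_all are hypothetical names), such a macro can be expanded into an integer pack to stamp out one instantiation per size:

#include <utility>

#define FastFFT_build_sizes 512, 4096 // value set by this commit

// Hypothetical per-size hook; a real build might trigger an explicit
// template instantiation of a kernel specialized for Size here.
template <unsigned Size>
void instantiate_for_size( ) { }

// Fold expression calls the hook once for every size in the pack.
template <unsigned... Sizes>
void instantiate_all(std::integer_sequence<unsigned, Sizes...>) {
    (instantiate_for_size<Sizes>( ), ...);
}

// The macro expands in place, yielding the pack {512, 4096}.
inline void instantiate_build_sizes( ) {
    instantiate_all(std::integer_sequence<unsigned, FastFFT_build_sizes>{ });
}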
(Diff for the third changed file not shown.)
