Skip to content

Commit

Permalink
cuts down a bit on shared mem needed for coalesced c2r
Browse files Browse the repository at this point in the history
  • Loading branch information
bHimes committed Dec 27, 2024
1 parent 7554b34 commit f39f2d7
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions src/fastfft/FastFFT.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1656,7 +1656,19 @@ __launch_bounds__(MAX_TPB) __global__
thread_data,
shared_mem,
gridDim.y * n_ffts);
FFT( ).execute(thread_data, &shared_mem[FFT::shared_memory_size / sizeof(complex_compute_t) * threadIdx.y], workspace);

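// To cut down on the total shared memory needed, run the FFT in two passes:
// even threadIdx.y lines first, then odd threadIdx.y lines, with a block-wide
// barrier after each pass. Each even/odd pair of lines reuses the same
// workspace slot (indexed by threadIdx.y / 2), halving the FFT workspace.
// NOTE(review): execute() is now called from a branch taken by only half of
// the block - confirm it does not issue a block-wide barrier internally.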
// FFT( ).execute(thread_data, &shared_mem[FFT::shared_memory_size / sizeof(complex_compute_t) * threadIdx.y], workspace);

const unsigned int eve_odd = threadIdx.y % 2;
if ( eve_odd == 0 ) {
FFT( ).execute(thread_data, &shared_mem[FFT::shared_memory_size / sizeof(complex_compute_t) * (threadIdx.y / 2)], workspace);
}
__syncthreads( );
if ( eve_odd == 1 ) {
FFT( ).execute(thread_data, &shared_mem[FFT::shared_memory_size / sizeof(complex_compute_t) * (threadIdx.y / 2)], workspace);
}
__syncthreads( );

#else
static_assert(n_ffts == 1, "C2R_BUFFER_LINES must be enabled, should not get hereonly for n_ffts == 1");
Expand Down Expand Up @@ -2623,8 +2635,8 @@ void FourierTransformer<ComputeBaseType, InputType, OtherImageType, Rank>::SetAn
// TODO: other asserts, but basically, for the buffering we want to have at least 32 threads per block, and we can't modify x
}

// Add enough shared mem to swap on each read
shared_memory = std::max(size_t(FFT::shared_memory_size * n_buffer_lines), size_t(FFT::stride * n_buffer_lines * sizeof(complex_compute_t)));
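// Add enough shared memory to swap on each read. The FFT workspace term is
// halved because even and odd buffer lines execute in separate passes and
// share a workspace slot. // revert
// NOTE(review): n_buffer_lines / 2 truncates - confirm n_buffer_lines is
// always even, otherwise the last threadIdx.y / 2 slot may not be covered.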
shared_memory = std::max(size_t(FFT::shared_memory_size * n_buffer_lines / 2), size_t(FFT::stride * n_buffer_lines * sizeof(complex_compute_t)));

#else

Expand All @@ -2636,6 +2648,7 @@ void FourierTransformer<ComputeBaseType, InputType, OtherImageType, Rank>::SetAn
// PrintLaunchParameters(LP);
// PrintState( );
// std::cerr << "smem " << shared_memory << std::endl;
// std::cerr << "FFT::shared_memory_size " << FFT::shared_memory_size << std::endl;
// std::cerr << "max tpb " << max_threads_per_block << " n_buffer " << n_buffer_lines << std::endl;
// exit(0);

Expand Down Expand Up @@ -2830,9 +2843,11 @@ void FourierTransformer<ComputeBaseType, InputType, OtherImageType, Rank>::SetAn
}
case generic_fwd_increase_op_inv_none: {
if constexpr ( FFT_ALGO_t == Generic_Fwd_Image_Inv_FFT ) {
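// revert: with 16 elements per thread this method uses 128 registers/thread,
// which limits occupancy to 2 blocks/SM; testing 8.
// NOTE(review): the line below currently pins ElementsPerThread<16>, not 8 -
// confirm which value is intended.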
using mod_base = cufftdx::replace_t<FFT_base_arch, ElementsPerThread<16>>;
// For convenience, we are explicitly zero-padding. This is lazy. FIXME
using FFT = decltype(FFT_base_arch( ) + Type<fft_type::c2c>( ) + Direction<fft_direction::forward>( ));
using invFFT = decltype(FFT_base_arch( ) + Type<fft_type::c2c>( ) + Direction<fft_direction::inverse>( ));
using FFT = decltype(mod_base( ) + Type<fft_type::c2c>( ) + Direction<fft_direction::forward>( ));
using invFFT = decltype(mod_base( ) + Type<fft_type::c2c>( ) + Direction<fft_direction::inverse>( ));

LaunchParams LP = SetLaunchParameters(generic_fwd_increase_op_inv_none, FFT::elements_per_thread);

Expand Down

0 comments on commit f39f2d7

Please sign in to comment.