diff --git a/src/fastfft/FastFFT.cu b/src/fastfft/FastFFT.cu index f863380..4162224 100644 --- a/src/fastfft/FastFFT.cu +++ b/src/fastfft/FastFFT.cu @@ -1656,7 +1656,19 @@ __launch_bounds__(MAX_TPB) __global__ thread_data, shared_mem, gridDim.y * n_ffts); - FFT( ).execute(thread_data, &shared_mem[FFT::shared_memory_size / sizeof(complex_compute_t) * threadIdx.y], workspace); + + // To cut down on total smem needed split these + // FFT( ).execute(thread_data, &shared_mem[FFT::shared_memory_size / sizeof(complex_compute_t) * threadIdx.y], workspace); + + const unsigned int eve_odd = threadIdx.y % 2; + if ( eve_odd == 0 ) { + FFT( ).execute(thread_data, &shared_mem[FFT::shared_memory_size / sizeof(complex_compute_t) * (threadIdx.y / 2)], workspace); + } + __syncthreads( ); + if ( eve_odd == 1 ) { + FFT( ).execute(thread_data, &shared_mem[FFT::shared_memory_size / sizeof(complex_compute_t) * (threadIdx.y / 2)], workspace); + } + __syncthreads( ); #else static_assert(n_ffts == 1, "C2R_BUFFER_LINES must be enabled, should not get hereonly for n_ffts == 1"); @@ -2623,8 +2635,8 @@ void FourierTransformer::SetAn // TODO: other asserts, but basically, for the buffering we want to have at least 32 threads per block, and we can't modify x } - // Add enough shared mem to swap on each read - shared_memory = std::max(size_t(FFT::shared_memory_size * n_buffer_lines), size_t(FFT::stride * n_buffer_lines * sizeof(complex_compute_t))); + // Add enough shared mem to swap on each read // revert + shared_memory = std::max(size_t(FFT::shared_memory_size * n_buffer_lines / 2), size_t(FFT::stride * n_buffer_lines * sizeof(complex_compute_t))); #else @@ -2636,6 +2648,7 @@ void FourierTransformer::SetAn // PrintLaunchParameters(LP); // PrintState( ); // std::cerr << "smem " << shared_memory << std::endl; + // std::cerr << "FFT::shared_memory_size " << FFT::shared_memory_size << std::endl; // std::cerr << "max tpb " << max_threads_per_block << " n_buffer " << n_buffer_lines << std::endl; // exit(0); @@ -2830,9 +2843,11 @@ void FourierTransformer::SetAn } case generic_fwd_increase_op_inv_none: { if constexpr ( FFT_ALGO_t == Generic_Fwd_Image_Inv_FFT ) { + // revert : with 16 ept this method uses 128 reg/thread which limits to 2 blocks /sm, testing 8 + using mod_base = cufftdx::replace_t>; // For convenience, we are explicitly zero-padding. This is lazy. FIXME - using FFT = decltype(FFT_base_arch( ) + Type( ) + Direction( )); - using invFFT = decltype(FFT_base_arch( ) + Type( ) + Direction( )); + using FFT = decltype(mod_base( ) + Type( ) + Direction( )); + using invFFT = decltype(mod_base( ) + Type( ) + Direction( )); LaunchParams LP = SetLaunchParameters(generic_fwd_increase_op_inv_none, FFT::elements_per_thread);