From f4b1aa4bf8744d3b6f0ec3d89d9d409ef53c92b3 Mon Sep 17 00:00:00 2001 From: himesb Date: Sat, 21 Dec 2024 16:22:08 -0500 Subject: [PATCH] unrolled and fixed assignment to fft buffer --- include/FastFFT.cuh | 46 +++++++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/include/FastFFT.cuh b/include/FastFFT.cuh index 3a2bd54..2892870 100644 --- a/include/FastFFT.cuh +++ b/include/FastFFT.cuh @@ -450,6 +450,12 @@ struct WarpTiler { return (read_multiplier * FFT::storage_size * physical_x_in_tile / 2) + read_multiplier * i; } + static inline __device__ unsigned int thread_data_idx_real_part_fft_0(const unsigned int i) { + return read_multiplier * i; + } + + constexpr unsigned int thread_data_fft_stride = read_multiplier * FFT::storage_size; + // This can't be right, which means data_index_to_read_1d must be wrong static inline __device__ int get_producer_thread_tile_index_y(const unsigned int i_read) { // A consumer thread will be wanting the data from input_data[threadIdx.x + i*FFT::stride] @@ -1042,28 +1048,32 @@ struct io { __syncwarp( ); int producer_thread_re_lane_idx warp_tiler::get_producer_thread_re_lane_idx(i_tile); - - float copied_val = __shfl_sync(0xFFFFFFFF, read_val, producer_thread_re_lane_idx, 32); + const bool is_active_consumer = producer_thread_re_lane_idx < warp_tiler::warp_size; unsigned int thread_data_linear_idx = warp_tiler::thread_data_linear_idx_real_part(i); - if ( producer_thread_re_lane_idx < warp_tiler::warp_size && warp_tiler::invalid_read_v != copied_val ) { - // printf("tidx: %i tidy: %i blockIDx.y: %i read val: %2.2f lane: %i i: %i warp: %i tile_idx_x: %i x: %i y: %i readidx: %i n_sub_warp_blocks: %i n_coalesced_ffts: %i FFT::input_ept: %i\n", - // threadIdx.x, threadIdx.y, blockIdx.y, read_val, lane_idx, i, i_tile, tile_idx_x, physical_x, physical_y_in_warp, data_index_to_read_1d, n_sub_warp_blocks, n_coalesced_ffts, FFT::input_ept); - // printf("x: %i, y:%i l:%i from: %i, readVal:%3.3f, copiedVal:%3.3f\n", physical_x, base_physical_y_in_warp, producer_tidx, read_val, copied_val); - thread_data[thread_data_linear_idx] = copied_val; + unsigned int thread_index = warp_tiler::thread_data_idx_real_part_fft_0(i); +#pragma unroll(n_coalesced_ffts) + for ( int i_fft = 0; i_fft < n_coalesced_ffts; i_fft++ ) { + float copied_val = __shfl_sync(0xFFFFFFFF, read_val, producer_thread_re_lane_idx, 32); + + if ( is_active_consumer && warp_tiler::invalid_read_v != copied_val ) { + // printf("tidx: %i tidy: %i blockIDx.y: %i read val: %2.2f lane: %i i: %i warp: %i tile_idx_x: %i x: %i y: %i readidx: %i n_sub_warp_blocks: %i n_coalesced_ffts: %i FFT::input_ept: %i\n", + // threadIdx.x, threadIdx.y, blockIdx.y, read_val, lane_idx, i, i_tile, tile_idx_x, physical_x, physical_y_in_warp, data_index_to_read_1d, n_sub_warp_blocks, n_coalesced_ffts, FFT::input_ept); + // printf("x: %i, y:%i l:%i from: %i, readVal:%3.3f, copiedVal:%3.3f\n", physical_x, base_physical_y_in_warp, producer_tidx, read_val, copied_val); + thread_data[thread_index] = copied_val; + } + copied_val = __shfl_sync(0xFFFFFFFF, read_val, producer_tidx + 1, 32); + if ( is_active_consumer && warp_tiler::invalid_read_v != copied_val ) { + // printf("tidx: %i tidy: %i blockIDx.y: %i read val: %2.2f lane: %i i: %i warp: %i tile_idx_x: %i x: %i y: %i readidx: %i n_sub_warp_blocks: %i n_coalesced_ffts: %i FFT::input_ept: %i\n", + // threadIdx.x, threadIdx.y, blockIdx.y, read_val, lane_idx, i, i_tile, tile_idx_x, physical_x, physical_y_in_warp, data_index_to_read_1d, n_sub_warp_blocks, n_coalesced_ffts, FFT::input_ept); + // printf("x: %i, y:%i l:%i from: %i, readVal:%3.3f, copiedVal:%3.3f\n", physical_x, base_physical_y_in_warp, producer_tidx, read_val, copied_val); + thread_data[thread_index + 1] = copied_val; + } + __syncwarp( ); + + thread_index += warp_tiler::thread_data_fft_stride; } - __syncwarp( ); - // n_coalesced_ffts * read_multiplier is just tile pitch which I'm calculting like 4 different ways here - - copied_val = __shfl_sync(0xFFFFFFFF, read_val, producer_tidx + 1, 32); - if ( producer_thread_re_lane_idx < warp_tiler::warp_size && warp_tiler::invalid_read_v != copied_val ) { - // printf("tidx: %i tidy: %i blockIDx.y: %i read val: %2.2f lane: %i i: %i warp: %i tile_idx_x: %i x: %i y: %i readidx: %i n_sub_warp_blocks: %i n_coalesced_ffts: %i FFT::input_ept: %i\n", - // threadIdx.x, threadIdx.y, blockIdx.y, read_val, lane_idx, i, i_tile, tile_idx_x, physical_x, physical_y_in_warp, data_index_to_read_1d, n_sub_warp_blocks, n_coalesced_ffts, FFT::input_ept); - // printf("x: %i, y:%i l:%i from: %i, readVal:%3.3f, copiedVal:%3.3f\n", physical_x, base_physical_y_in_warp, producer_tidx, read_val, copied_val); - thread_data[thread_data_linear_idx + 1] = copied_val; - } - __syncwarp( ); } } }