Skip to content

Commit

Permalink
unrolled and fixed assignment to fft buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
bHimes committed Dec 21, 2024
1 parent ed18314 commit f4b1aa4
Showing 1 changed file with 28 additions and 18 deletions.
46 changes: 28 additions & 18 deletions include/FastFFT.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,12 @@ struct WarpTiler {
return (read_multiplier * FFT::storage_size * physical_x_in_tile / 2) + read_multiplier * i;
}

// Returns the element index `i` scaled by `read_multiplier`, with no additional offset.
//
// NOTE(review): judging by the name, this is the thread_data index of the real part of
// element `i` for the first coalesced FFT, with later FFTs reached by adding a per-FFT
// stride on top of it — confirm against the caller.
static inline __device__ unsigned int thread_data_idx_real_part_fft_0(const unsigned int i) {
    const unsigned int scaled_idx = i * read_multiplier;
    return scaled_idx;
}

// Compile-time offset equal to read_multiplier * FFT::storage_size.
// NOTE(review): presumably the distance in thread_data between the same element of two
// consecutive coalesced FFTs — confirm against the code that advances the index by it.
//
// Must be `static`: a constexpr data member of a class is only legal when it is static,
// and the value is meant to be read through the type (Type::thread_data_fft_stride)
// without an instance.
static constexpr unsigned int thread_data_fft_stride = read_multiplier * FFT::storage_size;

// This can't be right, which means data_index_to_read_1d must be wrong
static inline __device__ int get_producer_thread_tile_index_y(const unsigned int i_read) {
// A consumer thread will be wanting the data from input_data[threadIdx.x + i*FFT::stride]
Expand Down Expand Up @@ -1042,28 +1048,32 @@ struct io {

__syncwarp( );
int producer_thread_re_lane_idx warp_tiler::get_producer_thread_re_lane_idx(i_tile);

float copied_val = __shfl_sync(0xFFFFFFFF, read_val, producer_thread_re_lane_idx, 32);
const bool is_active_consumer = producer_thread_re_lane_idx < warp_tiler::warp_size;

unsigned int thread_data_linear_idx = warp_tiler::thread_data_linear_idx_real_part(i);

if ( producer_thread_re_lane_idx < warp_tiler::warp_size && warp_tiler::invalid_read_v != copied_val ) {
// printf("tidx: %i tidy: %i blockIDx.y: %i read val: %2.2f lane: %i i: %i warp: %i tile_idx_x: %i x: %i y: %i readidx: %i n_sub_warp_blocks: %i n_coalesced_ffts: %i FFT::input_ept: %i\n",
// threadIdx.x, threadIdx.y, blockIdx.y, read_val, lane_idx, i, i_tile, tile_idx_x, physical_x, physical_y_in_warp, data_index_to_read_1d, n_sub_warp_blocks, n_coalesced_ffts, FFT::input_ept);
// printf("x: %i, y:%i l:%i from: %i, readVal:%3.3f, copiedVal:%3.3f\n", physical_x, base_physical_y_in_warp, producer_tidx, read_val, copied_val);
thread_data[thread_data_linear_idx] = copied_val;
unsigned int thread_index = warp_tiler::thread_data_idx_real_part_fft_0(i);
#pragma unroll(n_coalesced_ffts)
for ( int i_fft = 0; i_fft < n_coalesced_ffts; i_fft++ ) {
float copied_val = __shfl_sync(0xFFFFFFFF, read_val, producer_thread_re_lane_idx, 32);

if ( is_active_consumer && warp_tiler::invalid_read_v != copied_val ) {
// printf("tidx: %i tidy: %i blockIDx.y: %i read val: %2.2f lane: %i i: %i warp: %i tile_idx_x: %i x: %i y: %i readidx: %i n_sub_warp_blocks: %i n_coalesced_ffts: %i FFT::input_ept: %i\n",
// threadIdx.x, threadIdx.y, blockIdx.y, read_val, lane_idx, i, i_tile, tile_idx_x, physical_x, physical_y_in_warp, data_index_to_read_1d, n_sub_warp_blocks, n_coalesced_ffts, FFT::input_ept);
// printf("x: %i, y:%i l:%i from: %i, readVal:%3.3f, copiedVal:%3.3f\n", physical_x, base_physical_y_in_warp, producer_tidx, read_val, copied_val);
thread_data[thread_index] = copied_val;
}
copied_val = __shfl_sync(0xFFFFFFFF, read_val, producer_tidx + 1, 32);
if ( is_active_consumer && warp_tiler::invalid_read_v != copied_val ) {
// printf("tidx: %i tidy: %i blockIDx.y: %i read val: %2.2f lane: %i i: %i warp: %i tile_idx_x: %i x: %i y: %i readidx: %i n_sub_warp_blocks: %i n_coalesced_ffts: %i FFT::input_ept: %i\n",
// threadIdx.x, threadIdx.y, blockIdx.y, read_val, lane_idx, i, i_tile, tile_idx_x, physical_x, physical_y_in_warp, data_index_to_read_1d, n_sub_warp_blocks, n_coalesced_ffts, FFT::input_ept);
// printf("x: %i, y:%i l:%i from: %i, readVal:%3.3f, copiedVal:%3.3f\n", physical_x, base_physical_y_in_warp, producer_tidx, read_val, copied_val);
thread_data[thread_index + 1] = copied_val;
}
__syncwarp( );

thread_index += warp_tiler::thread_data_fft_stride;
}
__syncwarp( );
// n_coalesced_ffts * read_multiplier is just the tile pitch, which I'm calculating like 4 different ways here

copied_val = __shfl_sync(0xFFFFFFFF, read_val, producer_tidx + 1, 32);
if ( producer_thread_re_lane_idx < warp_tiler::warp_size && warp_tiler::invalid_read_v != copied_val ) {
// printf("tidx: %i tidy: %i blockIDx.y: %i read val: %2.2f lane: %i i: %i warp: %i tile_idx_x: %i x: %i y: %i readidx: %i n_sub_warp_blocks: %i n_coalesced_ffts: %i FFT::input_ept: %i\n",
// threadIdx.x, threadIdx.y, blockIdx.y, read_val, lane_idx, i, i_tile, tile_idx_x, physical_x, physical_y_in_warp, data_index_to_read_1d, n_sub_warp_blocks, n_coalesced_ffts, FFT::input_ept);
// printf("x: %i, y:%i l:%i from: %i, readVal:%3.3f, copiedVal:%3.3f\n", physical_x, base_physical_y_in_warp, producer_tidx, read_val, copied_val);
thread_data[thread_data_linear_idx + 1] = copied_val;
}
__syncwarp( );
}
}
}
Expand Down

0 comments on commit f4b1aa4

Please sign in to comment.