diff --git a/include/FastFFT.cuh b/include/FastFFT.cuh
index 76501ce..3a2bd54 100644
--- a/include/FastFFT.cuh
+++ b/include/FastFFT.cuh
@@ -349,12 +349,146 @@ inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, c
     }
 }
 
-//////////////////////////////////////////////
-// IO functions adapted from the cufftdx examples
-///////////////////////////////
+template <bool flag = false>
+inline void static_assert_allowed_read_size( ) { static_assert(flag, "static_assert_allowed_read_size"); }
+
+template <class FFT, unsigned int n_coalesced_ffts>
+struct WarpTiler {
+    using complex_compute_t = typename FFT::value_type;
+    using scalar_compute_t  = typename complex_compute_t::value_type;
+
+    // Drafting for C2R XY specifically
+
+    // Describe in terms of 2D, as 3D is handled as a stack of 2D
+
+    // Some steps are substantially simplified if all threads that are producers are also consumers. This means (whoops, only half the threads will be consumers then)
+    // we need blockDim.x >= 32, and with EPT restricted to pow 2, the min size is 64.
+    // For these smaller arrays, L2 cache hits are far more likely, and can be encouraged further on arch > 8.0, or alternatively handled
+    // by an intra-block smem transpose if the thread FFT api is used. In any case, for the current project, we mostly care about large
+    // FFTs anyway.
+
+    // We have an array of M complex valued FFTs, all with their 0th index along the fast dimension (for this example).
+    // The FFT dimension is N, and for C2R in a natural layout we have N/2 + 1 complex inputs per FFT.
+    // blockDim.x is fixed to be N / ept (FFT::elements_per_thread) = FFT::stride
+    // For C2R we have ept/2+1 = input_ept, and at most that many read cycles/FFT, while many threads are excluded on the last read. (We may be able to fix that with this method.)
+    // Each read cycle has blockDim.x reads by contiguous threads, and each thread has an array T[FFT::storage_size] where the data is expected to be T[i] = data[tid.x + i*FFT::stride]
+
+    // For block methods _XY the data are transposed in memory, so the reads are not coalesced.
+    // Normally we would have gridDim.y blocks, but we reduce that here by n_coalesced_ffts.
+
+    // We read in values at 4 bytes / thread to maximize coalescing, and also b/c I think that is all a warp shuffle supports (though I could be wrong on that).
+
+    // The goal is to have n_coalesced_ffts as large as possible without killing occupancy due to register pressure.
+
+    // Each read cycle, a warp will have to read n_coalesced_ffts * 2 4-byte chunks, leaving warp_size (32) / (n_coalesced_ffts * 2) complex elements read for each FFT.
+
+    // We require that blockDim.x = FFT::stride = n expected reads / cycle be >= this value. EPT is restricted to a power of 2, so if it is greater, it is greater by at least a factor of 2 and the ratio is a whole number.
+
+    // FIXME: need to add checks on data layout (RIRI vs RRII)
+    // FIXME: need to add checks on data sizes and allow more flexibility
+    static constexpr unsigned int read_bytes = 4;
+    static_assert(sizeof(complex_compute_t) == 2 * read_bytes, "warp_tiler requires complex_compute_t to be 8 bytes atm.");
+    static_assert(sizeof(scalar_compute_t) == read_bytes, "warp_tiler requires scalar_compute_t to be 4 bytes atm.");
+    static constexpr unsigned int read_multiplier = 2;
+
+    static constexpr unsigned int expected_reads_per_cycle = FFT::stride;
+    static constexpr unsigned int warp_size                = 32; // I think this is true for all arch as of 2024 FIXME: confirm
+
+    // Guarded by the static_assert on scalar_compute_t above; overkill for now, but for handling half later
+    static constexpr unsigned int n_4byte_reads_per_complex_element = 2;
+
+    static constexpr bool         strided_XY              = true; // for later configuration
+    static constexpr unsigned int tile_thread_dim_fast    = n_coalesced_ffts * n_4byte_reads_per_complex_element;
+    static constexpr unsigned int tile_thread_dim_strided = warp_size / tile_thread_dim_fast;
+
+    static_assert(tile_thread_dim_strided <= expected_reads_per_cycle, "there are not enough x threads for this method.");
+    static_assert(expected_reads_per_cycle >= warp_size, "Min blockDim.x 32 for this optimization");
+
+    // some aliases, but I guess we could just omit the above two lines.
+    static constexpr unsigned int tile_thread_dim_x = strided_XY ? tile_thread_dim_fast : tile_thread_dim_strided;
+    static constexpr unsigned int tile_thread_dim_y = strided_XY ? tile_thread_dim_strided : tile_thread_dim_fast;
+
+    static constexpr unsigned int n_tile_reads_per_cycle = expected_reads_per_cycle / tile_thread_dim_strided;
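+
+    // Illustrative worked example (not from the original notes; numbers assume N = 256, ept = 8, n_coalesced_ffts = 4):
+    //   blockDim.x = FFT::stride = 256 / 8 = 32 = expected_reads_per_cycle
+    //   tile_thread_dim_fast    = 4 ffts * 2 floats (re/im) = 8 threads along the fast dimension
+    //   tile_thread_dim_strided = 32 / 8 = 4 rows of the strided dimension per tile read
+    //   n_tile_reads_per_cycle  = 32 / 4 = 8 tile reads per cycle
+    //   Each tile read therefore moves 32 floats = 16 complex values = 4 rows x 4 coalesced FFTs, and one cycle of
+    //   8 tile reads covers 32 rows, matching the FFT::stride rows consumed per elements-per-thread iteration i.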
+
+    // Some threads may not have a valid read in a given cycle, but other threads won't know that.
+    // Even a thread without a valid read on the tile may be a consumer through a warp shuffle, so we can't exclude it with logic.
+    // Instead, use a set value to communicate this.
+    // Guarded by the static_assert on scalar_compute_t above; overkill for now, but for handling half later
+    static constexpr float invalid_read_v = -std::numeric_limits<scalar_compute_t>::max( );
+
+    // We may have threads coming from x/y, so we get the lane id from the special reg, which could be a tiny bit slower, but is more dependable than threadIdx.x & 31
+    static inline __device__ unsigned int lane_idx( ) { return get_lane_id( ); }
+
+    // This should be safe b/c if blockDim.x >= 32 we should have blockDim.y = 1. For now, I have an assert prior to the kernel launch to save overhead.
+    static inline __device__ unsigned int tile_idx( ) {
+        MyFFTDebugAssertTestTrue(blockDim.y == 1 || blockDim.x * blockDim.y == 32, "Incorrect partitioning of x/y threads for warp tiler.");
+        return blockDim.y == 1 ? threadIdx.x / warp_size : 0;
+    }
+
+    // TODO: rename and make invertible
+    static inline __device__ unsigned int physical_y_in_tile( ) { return lane_idx( ) / tile_thread_dim_x; }
+    static inline __device__ unsigned int physical_x_in_tile( ) { return lane_idx( ) % tile_thread_dim_x; }
+
+    // in the normal approach this will be threadIdx.x + i*FFT::stride (which we'll need to define the consumer thread)
+    // I think this will be at most 2 instructions (IMAD), but maybe it is better to calc an initial value and increment it
+    static inline __device__ unsigned int data_index_to_read_1d(const unsigned int i, const unsigned int i_read) {
+        return (physical_y_in_tile( ) + tile_idx( ) * warp_size) + i_read * n_tile_reads_per_cycle + i * FFT::stride;
+    }
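+
+    // Illustrative trace of data_index_to_read_1d as written (same assumed config as above: stride = 32, n_tile_reads_per_cycle = 8):
+    //   lane 13 in warp/tile 0 has physical_y_in_tile = 13 / 8 = 1, so for i = 0 it reads rows 1, 9, 17, ... as i_read = 0..7.
+    //   Note the per-i_read step is n_tile_reads_per_cycle (8) while each tile read only covers tile_thread_dim_strided (4)
+    //   rows, so rows 4-7 of each 8-row span are never visited; stepping by tile_thread_dim_strided instead would tile
+    //   rows 0..31 exactly. This may be the source of the "This can't be right" doubt recorded below, and
+    //   get_producer_thread_tile_index_y would need the matching change, since it inverts this mapping.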
+
+    // FIXME: there should be an assert in the main program, prior to launching this kernel, that things are square;
+    // leaving it templated here so it is more obvious if those restrictions are relaxed in the future.
+    template <unsigned int PixelPitch = 0>
+    static inline __device__ unsigned int data_index_to_read_2d(const unsigned int data_index_to_read_1d_value) {
+        if constexpr ( PixelPitch == 0 ) {
+            // Generally we have a square FFT, so we can infer the pixel pitch at compile time
+            return (size_of<FFT>::value * read_multiplier) * data_index_to_read_1d_value + physical_x_in_tile( );
+        }
+        else {
+            return PixelPitch * data_index_to_read_1d_value + physical_x_in_tile( );
+        }
+    }
+
+    static inline __device__ unsigned int thread_data_linear_idx_real_part(const unsigned int i) {
+        // The consumer's i_fft (index of the coalesced FFT) is just physical_x_in_tile (FIXME, may be C2R XY specific here)
+        return (read_multiplier * FFT::storage_size * physical_x_in_tile( ) / 2) + read_multiplier * i;
+    }
+
+    // This can't be right, which means data_index_to_read_1d must be wrong
+    static inline __device__ int get_producer_thread_tile_index_y(const unsigned int i_read) {
+        // A consumer thread will be wanting the data from input_data[threadIdx.x + i*FFT::stride]
+        // The producer thread that fetched this index was (physical_y_in_tile + tile_idx * warp_size) + i_read * n_tile_reads_per_cycle + i * FFT::stride
+        // Solving for physical_y_in_tile = threadIdx.x - tile_idx * warp_size - i_read * n_tile_reads_per_cycle
+
+        // This will only be a valid index for a return value >= 0 && < n_tile_reads_per_cycle, which determines if we are an active consumer thread
+        return static_cast<int>(threadIdx.x) - static_cast<int>(tile_idx( ) * warp_size + i_read * n_tile_reads_per_cycle);
+    }
+
+    static inline __device__ unsigned int get_my_fft_idx( ) {
+        return physical_x_in_tile( ) / 2;
+    }
+
+    // the real part will be at the even (physical_x_in_tile, thread_tile_index_y);
+    // the imaginary part will just be + 1
+    static inline __device__ unsigned int get_producer_thread_re_lane_idx(const unsigned int i_read) {
+        int producer_thread_tile_index_y = get_producer_thread_tile_index_y(i_read);
+        if ( producer_thread_tile_index_y < 0 || producer_thread_tile_index_y >= static_cast<int>(n_tile_reads_per_cycle) ) {
+            return warp_size;
+        }
+        else {
+            // just a conversion from a 2d index in the tile to a 1d index == the producer's lane
+            return read_multiplier * get_my_fft_idx( ) + producer_thread_tile_index_y * tile_thread_dim_x;
+        }
+    }
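+
+    // Illustrative trace of the producer lookup (same assumed config: tile_thread_dim_x = 8, read_multiplier = 2):
+    //   a consumer whose solved producer_thread_tile_index_y is 1 and whose physical_x_in_tile is 5 belongs to
+    //   fft idx 5 / 2 = 2, so its real part comes from lane 2*2 + 1*8 = 12 (even x = 4, row 1) and its imaginary
+    //   part from lane 13 (odd x = 5, row 1), matching the even/odd layout described above.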
+
+    // Now every thread gets the producer value
+    // Every thread reads with a shfl_sync to a register copied value
+    // if (producer lane < warp_size) -> write to thread array
+    // maybe warpsync?
+    // Every thread reads again, but + 1
+    // if (producer lane)
+
+};
 
 template <class FFT>
 struct io {
+    using complex_compute_t = typename FFT::value_type;
     using scalar_compute_t  = typename complex_compute_t::value_type;
 
@@ -891,96 +1025,54 @@
                                                        const unsigned int pixel_pitch,
                                                        const unsigned int SignalLength = FFT::input_length) {
 
-        // Work this in steps toward generality
-        // Here we assume the input originally gridDim.y * FFT::input_length (complex) and there are blockDim.x = FFT::stride (y,z are 1)
-        // We need read multiplier because we can only swap 4 bytes (afaik) so we'll read re/im separately
-        constexpr unsigned int read_multiplier = 2; // to make it clear where we are indexing extra because we are reading re/im seperately.
-
-        // If we have fewer than
-
-        const unsigned int lane_idx = get_lane_id( ); // We may have threads coming from x/y so we get the lane ide from special reg, which could be a tiny bit slower, but more dependible than : threadIdx.x & 31;
-        // This should be safe b/c if blockDim.x >= 32 we should have blockDim.y = 1, For now, I have an assert prior to the kernel launch to save overhead
-        const unsigned int warp_idx        = threadIdx.x / 32;
-        const unsigned int logical_x_in_1d = (threadIdx.x & 31);
-        // Normally we would have a gridDim.y = pixel_pitch, but since here we have reduced the number of blocks to
-        // pixel_pitch / n_coalesced_ffts,
-        const unsigned int tile_idx_x = lane_idx % (n_coalesced_ffts * read_multiplier);
-        const unsigned int physical_x = tile_idx_x + (blockIdx.y * n_coalesced_ffts * read_multiplier);
-
-        constexpr unsigned int tile_size_y             = 16 / n_coalesced_ffts; // # threads in each dimension 4
-        constexpr unsigned int tile_size_x             = 32 / tile_size_y; // 8
-        constexpr unsigned int n_complex_vals_per_read = tile_size_y / 2; // 2
-        constexpr unsigned int n_sub_warp_blocks       = 32 / n_complex_vals_per_read; // also pitch in elements for the tile 16
-        constexpr float        no_val                  = scalar_compute_t{-std::numeric_limits<scalar_compute_t>::max( )};
-        const unsigned int     expected_reads_per_i    = blockDim.x; // also can get at compile time 4 (8 threads in y)
-
-        // In the normal implemenation there is only on read per loop, and every thread is both a producer and a consumer.
-        // so for 64 with 8 ept, there are 8 threads and a stride of 8
-        // In this implementation there is a minimum of 32 threads, so
+        using warp_tiler = WarpTiler<FFT, n_coalesced_ffts>;
 
+        //revert FFT::input_ept
         for ( unsigned int i = 0; i < FFT::input_ept; i++ ) {
-            unsigned int physical_y_in_warp      = lane_idx / tile_size_x; // 0 1 2 3
-            unsigned int base_physical_y_in_warp = physical_y_in_warp;
-            // All reads need to be within the warp to use shfl_sync, so we need to break up the input data into 2d blocks,
-            for ( unsigned int i_tile = 0; i_tile < std::min(std::max(1u, expected_reads_per_i / n_complex_vals_per_read), n_sub_warp_blocks); i_tile++ ) {
+            for ( unsigned int i_tile = 0; i_tile < warp_tiler::n_tile_reads_per_cycle; i_tile++ ) {
                 // 4 / 2 = 2, min(2, 16) = 2
                 // read in as floats
-                unsigned int read_from_data_index = (physical_y_in_warp + warp_idx * 32) + i * FFT::stride;
-                // i0 0: 0 1 2 3
-                //    1: 4 5 6 7
-                // i 1: 4 5 6 7 (8 9 10 11)
-                const float read_val =
-                        read_from_data_index < SignalLength ? reinterpret_cast<const float*>(input)[read_from_data_index * pixel_pitch + physical_x] : no_val;
+                const unsigned int data_index_to_read_1d = warp_tiler::data_index_to_read_1d(i, i_tile);
+                const float        read_val =
+                        data_index_to_read_1d < SignalLength ? reinterpret_cast<const float*>(input)[warp_tiler::data_index_to_read_2d(data_index_to_read_1d)] : warp_tiler::invalid_read_v;
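+
+                // Note on the shuffles below (descriptive, not from the original patch): __shfl_sync(0xFFFFFFFF, read_val, src, 32)
+                // hands every lane the read_val held by lane src, and all 32 lanes must execute the call. That is why threads
+                // that are not consumers this i_tile still run the shuffle and are filtered afterwards by the producer-lane
+                // sentinel (warp_size) and the invalid_read_v sentinel, rather than by a divergent branch around the shuffle.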
                 // if ( no_val != read_val )
                 //     printf("tidx: %i tidy: %i blockIDx.y: %i read val: %2.2f lane: %i i: %i warp: %i tile_idx_x: %i x: %i y: %i readidx: %i n_sub_warp_blocks: %i n_coalesced_ffts: %i FFT::input_ept: %i\n",
-                //            threadIdx.x, threadIdx.y, blockIdx.y, read_val, lane_idx, i, i_tile, tile_idx_x, physical_x, physical_y_in_warp, read_from_data_index, n_sub_warp_blocks, n_coalesced_ffts, FFT::input_ept);
-
-                for ( int i_fft = 0; i_fft < n_coalesced_ffts; i_fft++ ) {
-                    // all threads in the warp will now have data and we need to know who gets that data
-                    // this will result in some threads calculating something they don't need, but thats free
-                    // fft in the linear storage = FFT::storage_size * i_fft * 2 (2 because we are reading as floats)
-                    // + 2 * i (2 because we are reading as floats)
-                    unsigned int thread_data_linear_idx = read_multiplier * (FFT::storage_size * i_fft + i);
-                    // We may have read in a dummy value if we are beyond the signal length
-                    __syncwarp( );
-
-                    unsigned int producer_tidx = i_fft + 2 * tile_size_x * (base_physical_y_in_warp / 2);
-                    float        copied_val    = __shfl_sync(0xFFFFFFFF, read_val, producer_tidx, 32);
-
-                    if ( threadIdx.y == 0 &&
-                         (threadIdx.x & 31) >= i_tile * n_complex_vals_per_read &&
-                         (threadIdx.x & 31) < (i_tile + 1) * n_complex_vals_per_read &&
-                         no_val != copied_val ) {
-                        printf("tidx: %i tidy: %i blockIDx.y: %i read val: %2.2f lane: %i i: %i warp: %i tile_idx_x: %i x: %i y: %i readidx: %i n_sub_warp_blocks: %i n_coalesced_ffts: %i FFT::input_ept: %i\n",
-                               threadIdx.x, threadIdx.y, blockIdx.y, read_val, lane_idx, i, i_tile, tile_idx_x, physical_x, physical_y_in_warp, read_from_data_index, n_sub_warp_blocks, n_coalesced_ffts, FFT::input_ept);
-                        printf("x: %i, y:%i l:%i from: %i, readVal:%3.3f, copiedVal:%3.3f\n", physical_x, base_physical_y_in_warp, producer_tidx, read_val, copied_val);
-                        thread_data[thread_data_linear_idx] = copied_val;
-                    }
-                    __syncwarp( );
-                    // n_coalesced_ffts * read_multiplier is just tile pitch which I'm calculting like 4 different ways here
-                    copied_val = __shfl_sync(0xFFFFFFFF, read_val, producer_tidx + tile_size_x, 32);
-                    if ( threadIdx.y == 0 &&
-                         (threadIdx.x & 31) >= i_tile * n_complex_vals_per_read &&
-                         (threadIdx.x & 31) < (i_tile + 1) * n_complex_vals_per_read &&
-                         no_val != copied_val ) {
-                        printf("x: %i, y:%i l:%i from: %i, readVal:%3.3f, copiedVal:%3.3f\n", physical_x, base_physical_y_in_warp, producer_tidx, read_val, copied_val);
-                        thread_data[thread_data_linear_idx + 1] = copied_val;
-                    }
-
-                    __syncwarp( );
-                }
-
-                // increment the physical y in warp
-                physical_y_in_warp += tile_size_y;
+
+                __syncwarp( );
+                int producer_thread_re_lane_idx = warp_tiler::get_producer_thread_re_lane_idx(i_tile);
+
+                float copied_val = __shfl_sync(0xFFFFFFFF, read_val, producer_thread_re_lane_idx, 32);
+
+                unsigned int thread_data_linear_idx = warp_tiler::thread_data_linear_idx_real_part(i);
+
+                if ( producer_thread_re_lane_idx < warp_tiler::warp_size && warp_tiler::invalid_read_v != copied_val ) {
+                    // printf("tidx: %i tidy: %i blockIDx.y: %i read val: %2.2f lane: %i i: %i warp: %i tile_idx_x: %i x: %i y: %i readidx: %i n_sub_warp_blocks: %i n_coalesced_ffts: %i FFT::input_ept: %i\n",
+                    //        threadIdx.x, threadIdx.y, blockIdx.y, read_val, lane_idx, i, i_tile, tile_idx_x, physical_x, physical_y_in_warp, data_index_to_read_1d, n_sub_warp_blocks, n_coalesced_ffts, FFT::input_ept);
+                    // printf("x: %i, y:%i l:%i from: %i, readVal:%3.3f, copiedVal:%3.3f\n", physical_x, base_physical_y_in_warp, producer_tidx, read_val, copied_val);
+                    thread_data[thread_data_linear_idx] = copied_val;
+                }
+                __syncwarp( );
+                // n_coalesced_ffts * read_multiplier is just the tile pitch, which I'm calculating like 4 different ways here
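+
+                // The second shuffle below re-reads from the producer's odd neighbor lane: each complex element was
+                // fetched as two adjacent 4-byte floats, so the real part lives on an even lane of the tile row and
+                // the matching imaginary part on lane + 1, written to thread_data[idx] and thread_data[idx + 1].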
+
+                copied_val = __shfl_sync(0xFFFFFFFF, read_val, producer_thread_re_lane_idx + 1, 32);
+                if ( producer_thread_re_lane_idx < warp_tiler::warp_size && warp_tiler::invalid_read_v != copied_val ) {
+                    // printf("tidx: %i tidy: %i blockIDx.y: %i read val: %2.2f lane: %i i: %i warp: %i tile_idx_x: %i x: %i y: %i readidx: %i n_sub_warp_blocks: %i n_coalesced_ffts: %i FFT::input_ept: %i\n",
+                    //        threadIdx.x, threadIdx.y, blockIdx.y, read_val, lane_idx, i, i_tile, tile_idx_x, physical_x, physical_y_in_warp, data_index_to_read_1d, n_sub_warp_blocks, n_coalesced_ffts, FFT::input_ept);
+                    // printf("x: %i, y:%i l:%i from: %i, readVal:%3.3f, copiedVal:%3.3f\n", physical_x, base_physical_y_in_warp, producer_tidx, read_val, copied_val);
+                    thread_data[thread_data_linear_idx + 1] = copied_val;
+                }
+                __syncwarp( );
             }
         }
     }
 #endif
 
-    static inline __device__ void load_c2r_shared_and_pad(const complex_compute_t* __restrict__ input,
-                                                          complex_compute_t* __restrict__ shared_mem,
-                                                          const unsigned int pixel_pitch) {
+    static inline __device__ void
+    load_c2r_shared_and_pad(const complex_compute_t* __restrict__ input,
+                            complex_compute_t* __restrict__ shared_mem,
+                            const unsigned int pixel_pitch) {
         unsigned int index = threadIdx.x + (threadIdx.y * size_of<FFT>::value);
         for ( unsigned int i = 0; i < FFT::elements_per_thread / 2; i++ ) {