drafted warp tiler struct to simplify things.
bHimes committed Dec 21, 2024
1 parent 5a43474 commit ed18314
Showing 1 changed file with 171 additions and 79 deletions.
250 changes: 171 additions & 79 deletions include/FastFFT.cuh
@@ -349,12 +349,146 @@ inline __device__ SetTo_t convert_if_needed(const GetFrom_t* __restrict__ ptr, c
}
}

//////////////////////////////////////////////
// IO functions adapted from the cufftdx examples
//////////////////////////////////////////////
template <bool flag = false>
inline void static_assert_allowed_read_size( ) { static_assert(flag, "static_assert_allowed_read_size"); }

template <class FFT, unsigned int max_threads_per_block = FFT::max_threads_per_block, unsigned int n_coalesced_ffts = 1>
struct WarpTiler {
    using complex_compute_t = typename FFT::value_type;
    using scalar_compute_t  = typename complex_compute_t::value_type;

    // Drafting for C2R XY specifically.

    // Described in terms of 2D, as 3D is handled as a stack of 2D.

    // Some steps are substantially simplified if all threads that are producers are also consumers; this means (whoops) only half the threads will be consumers.
    // We need blockDim.x >= 32, and with EPT restricted to powers of 2, the minimum size is 64.
    // For these smaller arrays, L2 cache hits are far more likely and can be encouraged further on arch > 8.0, or alternatively handled
    // by an intra-block smem transpose if the thread FFT API is used. In any case, for the current project, we mostly care about large
    // FFTs anyway.

    // We have an array of M complex-valued FFTs, all with their 0th index along the fast dimension (for this example).
    // The FFT dimension is N, and for C2R in a natural layout we have N/2 + 1 complex inputs per FFT.
    // blockDim.x is fixed to be N / ept (FFT::elements_per_thread) = FFT::stride.
    // For C2R we have ept/2 + 1 = input_ept, and at most that many read cycles/FFT, while many threads are excluded on the last read (we may be able to fix that with this method).
    // Each read cycle has blockDim.x reads by contiguous threads, and each thread has an array T[FFT::storage_size] where the data is expected to be T[i] = data[tid.x + i*FFT::stride].
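
    // Concrete instance of the above (illustrative numbers only, not a supported configuration list): for N = 128 with
    // ept = 4, blockDim.x = FFT::stride = 128 / 4 = 32, and C2R has input_ept = 4 / 2 + 1 = 3, so thread tid.x expects
    // T[0] = data[tid.x], T[1] = data[tid.x + 32], T[2] = data[tid.x + 64], guarded against the N/2 + 1 = 65 valid inputs.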

    // For block methods _XY the data are transposed in memory, so the reads are not coalesced.
    // Normally we have gridDim.y blocks, but we reduce that here by n_coalesced_ffts.

    // We read in values at 4 bytes / thread to maximize coalescing, and also b/c I think that is all a warp shuffle supports (though I could be wrong on that).

    // The goal is to have n_coalesced_ffts as large as possible without killing occupancy due to register pressure.

    // Each read cycle, a warp will have to read n_coalesced_ffts * 2 4-byte chunks, leaving warp_size (32) / (n_coalesced_ffts * 2) complex elements read for each FFT.

    // We require that blockDim.x = FFT::stride = n expected reads / cycle be >= this value. EPT is restricted to a power of 2, so if it is greater it must be so by at least a factor of 2, and the quotient will be a nice whole number.
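
    // Continuing the illustrative N = 128, ept = 4 numbers with n_coalesced_ffts = 2: each warp read covers
    // n_coalesced_ffts * 2 = 4 floats in the fast dimension (re/im of 2 FFTs), leaving 32 / 4 = 8 complex
    // elements read per FFT per cycle, and blockDim.x = FFT::stride = 32 >= 8 as required.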

    // FIXME: need to add checks on data layout (RIRI vs RRII)
    // FIXME: need to add checks on data sizes and allow more flexibility
    static constexpr unsigned int read_bytes = 4;
    static_assert(sizeof(complex_compute_t) == 2 * read_bytes, "warp_tiler requires complex_compute_t to be 8 bytes atm.");
    static_assert(sizeof(scalar_compute_t) == read_bytes, "warp_tiler requires scalar_compute_t to be 4 bytes atm.");
    static constexpr unsigned int read_multiplier = 2;

    static constexpr unsigned int expected_reads_per_cycle = FFT::stride;
    static constexpr unsigned int warp_size = 32; // I think this is true for all arch as of 2024 FIXME: confirm

    // The static_asserts above pin scalar_compute_t to 4 bytes, so this is 2 for now; revisit (static_assert_allowed_read_size) when handling half later.
    static constexpr unsigned int n_4byte_reads_per_complex_element = 2;

    static constexpr bool         strided_XY           = true; // for later configuration
    static constexpr unsigned int tile_thread_dim_fast = n_coalesced_ffts * n_4byte_reads_per_complex_element;
    static constexpr unsigned int tile_thread_dim_strided = warp_size / tile_thread_dim_fast;

    static_assert(tile_thread_dim_strided <= expected_reads_per_cycle, "there are not enough x threads for this method.");
    static_assert(expected_reads_per_cycle >= warp_size, "Min blockDim.x 32 for this optimization");

    // Some aliases, though we could just omit the two lines above.
    static constexpr unsigned int tile_thread_dim_x = strided_XY ? tile_thread_dim_fast : tile_thread_dim_strided;
    static constexpr unsigned int tile_thread_dim_y = strided_XY ? tile_thread_dim_strided : tile_thread_dim_fast;

    static constexpr unsigned int n_tile_reads_per_cycle = expected_reads_per_cycle / tile_thread_dim_strided;

    // Some threads may not have a valid read in a given cycle, but other threads won't know that.
    // Even a thread without a valid read on the tile may be a consumer through a warp shuffle, so we can't exclude it with logic.
    // Instead, use a set value to communicate this.
    // The static_asserts above pin scalar_compute_t to float; revisit (static_assert_allowed_read_size) when handling half later.
    static constexpr float invalid_read_v = -std::numeric_limits<float>::max( );

    const unsigned int lane_idx = get_lane_id( ); // We may have threads coming from x/y so we get the lane id from the special reg, which could be a tiny bit slower, but is more dependable than threadIdx.x & 31

    // This should be safe b/c if blockDim.x >= 32 we should have blockDim.y == 1. For now, there is also an assert prior to the kernel launch to save overhead.
    inline __device__ WarpTiler( ) {
        MyFFTDebugAssertTestTrue(blockDim.y == 1 || blockDim.x * blockDim.y == 32, "Incorrect partitioning of x/y threads for warp tiler.");
    }

    const unsigned int tile_idx = blockDim.y == 1 ? threadIdx.x / warp_size : 0;

    // TODO: rename and make invertible
    const unsigned int physical_y_in_tile = lane_idx / tile_thread_dim_x;
    const unsigned int physical_x_in_tile = lane_idx % tile_thread_dim_x;
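
    // For the running illustrative example (n_coalesced_ffts = 2 -> tile_thread_dim_x = 4), the lane -> (x, y) mapping is:
    //   lane:  0  1  2  3 |  4  5  6  7 | ... | 28 29 30 31
    //      x:  0  1  2  3 |  0  1  2  3 | ... |  0  1  2  3   (x 0,1 = FFT 0 re/im; x 2,3 = FFT 1 re/im)
    //      y:  0  0  0  0 |  1  1  1  1 | ... |  7  7  7  7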

    // In the normal approach this will be threadIdx.x + i*FFT::stride (which we'll need to define the consumer thread).
    // I think this will be at most 2 instructions (IMAD), but maybe it is better to calc an initial value and increment it.
    inline __device__ unsigned int data_index_to_read_1d(const unsigned int i, const unsigned int i_read) const {
        return (physical_y_in_tile + tile_idx * warp_size) + i_read * n_tile_reads_per_cycle + i * FFT::stride;
    }
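
    // e.g. in the running example (tile_idx = 0, n_tile_reads_per_cycle = 4), lane 6 sits at (x, y) = (2, 1),
    // so data_index_to_read_1d(i = 0, i_read = 1) = 1 + 1 * 4 + 0 * 32 = 5, evaluating the formula as written.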

    // FIXME: there should be an assert in the main program prior to launching this kernel that things are square.
    // Leaving it templated here so it is more obvious if those restrictions are relaxed in the future.
    template <unsigned int PixelPitch = 0>
    inline __device__ unsigned int data_index_to_read_2d(const unsigned int data_index_to_read_1d_value) const {
        if constexpr ( PixelPitch == 0 ) {
            // Generally we have a square FFT, so we can infer the pixel pitch at compile time
            return (size_of<FFT>::value * read_multiplier) * data_index_to_read_1d_value + physical_x_in_tile;
        }
        else {
            return PixelPitch * data_index_to_read_1d_value + physical_x_in_tile;
        }
    }
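
    // e.g. with PixelPitch = 0 and the running example (assuming size_of<FFT>::value = 128), the inferred float
    // pitch is 128 * 2 = 256, so lane 6's read above lands at float index 256 * 5 + physical_x_in_tile(2) = 1282.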

    inline __device__ unsigned int thread_data_linear_idx_real_part(const unsigned int i) const {
        // The consumer's i_fft (index of the coalesced FFT) is physical_x_in_tile / 2 == get_my_fft_idx( ) (FIXME, may be C2R XY specific here)
        return read_multiplier * (FFT::storage_size * get_my_fft_idx( ) + i);
    }
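
    // e.g. lane 6 (physical_x_in_tile = 2 -> fft idx 1), assuming FFT::storage_size = 4 for illustration, with i = 1
    // gives 2 * (4 * 1 + 1) = 10, i.e. float slots 10 (re) and 11 (im) of the consumer's thread_data.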

    // This can't be right, which means data_index_to_read_1d must be wrong
    inline __device__ int get_producer_thread_tile_index_y(const unsigned int i_read) const {
        // A consumer thread wants the data from input_data[threadIdx.x + i*FFT::stride].
        // The producer thread that fetched this index read (physical_y_in_tile + tile_idx * warp_size) + i_read * n_tile_reads_per_cycle + i * FFT::stride.
        // Solving for physical_y_in_tile: threadIdx.x - tile_idx * warp_size - i_read * n_tile_reads_per_cycle.

        // This is only a valid index for a return value in [0, n_tile_reads_per_cycle), which determines whether we are an active consumer thread.
        // Cast to int so that an out-of-range consumer yields a negative value rather than unsigned wraparound.
        return static_cast<int>(threadIdx.x) - static_cast<int>(tile_idx * warp_size) - static_cast<int>(i_read * n_tile_reads_per_cycle);
    }

    inline __device__ unsigned int get_my_fft_idx( ) const {
        return physical_x_in_tile / 2;
    }

    // The real part will be at the even (physical_x_in_tile, thread_tile_index_y);
    // the imaginary part will just be +1.
    inline __device__ unsigned int get_producer_thread_re_lane_idx(const unsigned int i_read) const {
        const int producer_thread_tile_index_y = get_producer_thread_tile_index_y(i_read);
        if ( producer_thread_tile_index_y < 0 || producer_thread_tile_index_y >= static_cast<int>(n_tile_reads_per_cycle) ) {
            return warp_size;
        }
        else {
            // just a conversion from a 2d index in the tile to a 1d index == the producer's lane
            return read_multiplier * get_my_fft_idx( ) + producer_thread_tile_index_y * tile_thread_dim_x;
        }
    }
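
    // e.g. consumer threadIdx.x = 6 in the running example at i_read = 1: tile index y = 6 - 0 - 1 * 4 = 2 (valid),
    // so the re producer lane = 2 * 1 + 2 * 4 = 10, and the im part comes from lane 10 + 1 = 11. Lane 10 indeed read
    // row 2 + 1 * 4 + i * 32 = 6 + i * 32 = threadIdx.x + i * FFT::stride, which is what consumer 6 wants.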

    // Now every thread gets the producer value:
    //   every thread reads with a shfl_sync into a register-copied value;
    //   if (producer lane < warp_size) -> write to the thread array;
    //   maybe warpsync;
    //   every thread reads again, but from producer lane + 1 (the imaginary part);
    //   if (producer lane < warp_size) -> write to the thread array + 1.
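
    // A minimal sketch of that exchange as a member function, illustrative only: nothing calls it yet, the name
    // exchange_re_im_sketch is ours, and it assumes read_val is this thread's read for the current (i, i_read)
    // and that thread_data is the consumer's register array viewed as floats. The working version lives in the
    // io struct below.
    inline __device__ void exchange_re_im_sketch(const float read_val, float* thread_data, const unsigned int i, const unsigned int i_read) const {
        const unsigned int producer_re_lane_idx = get_producer_thread_re_lane_idx(i_read);
        // every lane participates in both shuffles; only lanes with a valid producer store the result
        float copied_val = __shfl_sync(0xFFFFFFFF, read_val, producer_re_lane_idx, warp_size);
        if ( producer_re_lane_idx < warp_size && invalid_read_v != copied_val )
            thread_data[thread_data_linear_idx_real_part(i)] = copied_val;
        __syncwarp( );
        copied_val = __shfl_sync(0xFFFFFFFF, read_val, producer_re_lane_idx + 1, warp_size);
        if ( producer_re_lane_idx < warp_size && invalid_read_v != copied_val )
            thread_data[thread_data_linear_idx_real_part(i) + 1] = copied_val;
        __syncwarp( );
    }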

};

template <class FFT, unsigned int max_threads_per_block = FFT::max_threads_per_block, unsigned int n_coalesced_ffts = 1>
struct io {

    using complex_compute_t = typename FFT::value_type;
    using scalar_compute_t  = typename complex_compute_t::value_type;

@@ -891,96 +1025,54 @@ struct io {
const unsigned int pixel_pitch,
const unsigned int SignalLength = FFT::input_length) {

        // Work this in steps toward generality.
        // Here we assume the input is originally gridDim.y * FFT::input_length (complex) and there are blockDim.x = FFT::stride threads (y, z are 1).
        // Normally we would have gridDim.y = pixel_pitch, but here we have reduced the number of blocks to pixel_pitch / n_coalesced_ffts.

        // In the normal implementation there is only one read per loop, and every thread is both a producer and a consumer,
        // so for size 64 with 8 ept there are 8 threads and a stride of 8. In this implementation there is a minimum of 32 threads.
        WarpTiler<FFT, max_threads_per_block, n_coalesced_ffts> warp_tiler{ };
        // revert FFT::input_ept
        for ( unsigned int i = 0; i < FFT::input_ept; i++ ) {

            // All reads need to be within the warp to use shfl_sync, so we need to break up the input data into 2d tiles.
            for ( unsigned int i_tile = 0; i_tile < warp_tiler.n_tile_reads_per_cycle; i_tile++ ) {
                // read in as floats
                const unsigned int data_index_to_read_1d = warp_tiler.data_index_to_read_1d(i, i_tile);
                const float        read_val =
                        data_index_to_read_1d < SignalLength ? reinterpret_cast<const float*>(input)[warp_tiler.data_index_to_read_2d(data_index_to_read_1d)] : warp_tiler.invalid_read_v;

                // if ( warp_tiler.invalid_read_v != read_val )
                //     printf("tidx: %i tidy: %i blockIDx.y: %i read val: %2.2f lane: %i i: %i i_tile: %i readidx: %i n_coalesced_ffts: %i FFT::input_ept: %i\n",
                //            threadIdx.x, threadIdx.y, blockIdx.y, read_val, warp_tiler.lane_idx, i, i_tile, data_index_to_read_1d, n_coalesced_ffts, FFT::input_ept);

                __syncwarp( );
                const unsigned int producer_thread_re_lane_idx = warp_tiler.get_producer_thread_re_lane_idx(i_tile);

                // the real part comes from the producer's (even) re lane
                float copied_val = __shfl_sync(0xFFFFFFFF, read_val, producer_thread_re_lane_idx, 32);

                const unsigned int thread_data_linear_idx = warp_tiler.thread_data_linear_idx_real_part(i);

                if ( producer_thread_re_lane_idx < warp_tiler.warp_size && warp_tiler.invalid_read_v != copied_val ) {
                    thread_data[thread_data_linear_idx] = copied_val;
                }
                __syncwarp( );

                // the imaginary part is at the producer's next lane (+1)
                copied_val = __shfl_sync(0xFFFFFFFF, read_val, producer_thread_re_lane_idx + 1, 32);
                if ( producer_thread_re_lane_idx < warp_tiler.warp_size && warp_tiler.invalid_read_v != copied_val ) {
                    thread_data[thread_data_linear_idx + 1] = copied_val;
                }
                __syncwarp( );
            }
        }
    }
#endif

    static inline __device__ void
    load_c2r_shared_and_pad(const complex_compute_t* __restrict__ input,
                            complex_compute_t* __restrict__ shared_mem,
                            const unsigned int pixel_pitch) {

unsigned int index = threadIdx.x + (threadIdx.y * size_of<FFT>::value);
for ( unsigned int i = 0; i < FFT::elements_per_thread / 2; i++ ) {
