Multi-pass scan
WeiqunZhang committed Feb 11, 2024
1 parent 8b9cbeb commit ddcc642
Showing 1 changed file with 199 additions and 4 deletions:
Src/Base/AMReX_Scan.H
@@ -400,7 +400,7 @@ T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE type, RetSum a_ret_sum = retSum

#ifndef AMREX_SYCL_NO_MULTIPASS_SCAN
if (nblocks > 1) {
-        return PrefixSum_mp<T>(n, std::forward<FIN>(fin), std::forward<FOUT>(fout), type, retSum);
+        return PrefixSum_mp<T>(n, std::forward<FIN>(fin), std::forward<FOUT>(fout), type, a_ret_sum);
}
#endif

@@ -621,7 +621,179 @@ T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE type, RetSum a_ret_sum = retSum
return totalsum;
}

-#elif defined(AMREX_USE_HIP)
+#else // #if defined(AMREX_USE_SYCL)

#define AMREX_GPU_MULTIPASS_SCAN 1
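// When defined, multi-block scans are routed through the multi-pass PrefixSum_mp below;
// removing the define keeps the pre-existing scan paths in use.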

#if defined(AMREX_GPU_MULTIPASS_SCAN)
template <int depth, typename T, typename N, typename FIN, typename FOUT, typename TYPE>
T PrefixSum_mp (N n, FIN && fin, FOUT && fout, TYPE, RetSum a_ret_sum)
{
if (n <= 0) { return 0; }
#if defined(AMREX_USE_HIP)
constexpr int nwarps_per_block = 4;
#else
constexpr int nwarps_per_block = 8;
#endif
constexpr int nthreads = nwarps_per_block*Gpu::Device::warp_size;
constexpr int nelms_per_thread = 12;
constexpr int nelms_per_block = nthreads * nelms_per_thread;
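// e.g. 256 threads x 12 elements = 3072 elements per block for an 8-warp CUDA block
// (32-wide warps) or a 4-wavefront HIP block (64-wide wavefronts).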
AMREX_ALWAYS_ASSERT(static_cast<Long>(n) < static_cast<Long>(std::numeric_limits<int>::max())*nelms_per_block);
int nblocks = (static_cast<Long>(n) + nelms_per_block - 1) / nelms_per_block;
std::size_t sm = 0;
auto stream = Gpu::gpuStream();

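// Scratch space in one arena allocation: block-local scan results for every element,
// one aggregate per block, and a single slot for the total sum.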
std::size_t nbytes_blockresult = Arena::align(sizeof(T)*n);
std::size_t nbytes_blocksum = Arena::align(sizeof(T)*nblocks);
std::size_t nbytes_totalsum = Arena::align(sizeof(T));
auto dp = (char*)(The_Arena()->alloc(nbytes_blockresult
+ nbytes_blocksum
+ nbytes_totalsum));
T* blockresult_p = (T*)dp;
T* blocksum_p = (T*)(dp + nbytes_blockresult);
T* totalsum_p = (T*)(dp + nbytes_blockresult + nbytes_blocksum);

amrex::launch(nblocks, nthreads, sm, stream,
[=] AMREX_GPU_DEVICE () noexcept
{
// Each block processes [ibegin,iend).
N ibegin = nelms_per_block * blockIdx.x;
N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);

T block_agg;
T data[nelms_per_thread];

constexpr bool is_exclusive = std::is_same<std::decay_t<TYPE>,Type::Exclusive>::value;

#if defined(AMREX_USE_CUDA)

using BlockLoad = cub::BlockLoad<T, nthreads, nelms_per_thread, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockScan = cub::BlockScan<T, nthreads, cub::BLOCK_SCAN_WARP_SCANS>;
using BlockExchange = cub::BlockExchange<T, nthreads, nelms_per_thread>;

__shared__ union TempStorage
{
typename BlockLoad::TempStorage load;
typename BlockExchange::TempStorage exchange;
typename BlockScan::TempStorage scan;
} temp_storage;
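// The union lets the load, scan, and exchange phases share one shared-memory buffer;
// the __syncthreads() calls between phases keep the reuse safe.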

auto input_lambda = [&] (N i) -> T { return fin(i+ibegin); };
cub::TransformInputIterator<T,decltype(input_lambda),cub::CountingInputIterator<N> >
input_begin(cub::CountingInputIterator<N>(0), input_lambda);

if (static_cast<int>(iend-ibegin) == nelms_per_block) {
BlockLoad(temp_storage.load).Load(input_begin, data);
} else {
BlockLoad(temp_storage.load).Load(input_begin, data, iend-ibegin, 0); // padding with 0
}

__syncthreads();

AMREX_IF_CONSTEXPR(is_exclusive) {
BlockScan(temp_storage.scan).ExclusiveSum(data, data, block_agg);
} else {
BlockScan(temp_storage.scan).InclusiveSum(data, data, block_agg);
}

__syncthreads();

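// The scanned values are in blocked order (each thread holds a contiguous run);
// convert to striped order so the store loop below can write with stride blockDim.x.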
BlockExchange(temp_storage.exchange).BlockedToStriped(data);

#else

using BlockLoad = rocprim::block_load<T, nthreads, nelms_per_thread,
rocprim::block_load_method::block_load_transpose>;
using BlockScan = rocprim::block_scan<T, nthreads,
rocprim::block_scan_algorithm::using_warp_scan>;
using BlockExchange = rocprim::block_exchange<T, nthreads, nelms_per_thread>;

__shared__ union TempStorage {
typename BlockLoad::storage_type load;
typename BlockExchange::storage_type exchange;
typename BlockScan::storage_type scan;
} temp_storage;

auto input_begin = rocprim::make_transform_iterator(
rocprim::make_counting_iterator(N(0)),
[&] (N i) -> T { return fin(i+ibegin); });

if (static_cast<int>(iend-ibegin) == nelms_per_block) {
BlockLoad().load(input_begin, data, temp_storage.load);
} else {
BlockLoad().load(input_begin, data, iend-ibegin, 0, temp_storage.load); // padding with 0
}

__syncthreads();

AMREX_IF_CONSTEXPR(is_exclusive) {
BlockScan().exclusive_scan(data, data, T{0}, block_agg, temp_storage.scan);
} else {
BlockScan().inclusive_scan(data, data, block_agg, temp_storage.scan);
}

__syncthreads();

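// Same as the CUDA branch: go from blocked to striped order for the strided store below.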
BlockExchange().blocked_to_striped(data, data, temp_storage.exchange);

#endif

for (int i = 0; i < nelms_per_thread; ++i) {
N offset = ibegin + i*blockDim.x + threadIdx.x;
if (offset < iend) {
blockresult_p[offset] = data[i];
}
}

if (threadIdx.x == 0) {
if (nblocks == 1) {
*totalsum_p = block_agg;
}
blocksum_p[blockIdx.x] = block_agg;
}
});

T totalsum = 0;
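// Scan the per-block aggregates themselves (recursion depth is capped) so each block
// can learn the sum of all preceding blocks.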
if (nblocks > 1) {
if constexpr (depth < 2) {
totalsum = PrefixSum_mp<depth+1,T>(nblocks,
[=] AMREX_GPU_DEVICE (int i)
{ return blocksum_p[i]; },
[=] AMREX_GPU_DEVICE (int i, T const& s)
{ blocksum_p[i] = s; },
Type::exclusive, a_ret_sum);
} else {
amrex::Abort("PrefixSum_mp: recursion is too deep");
The_Arena()->free(dp);
return totalsum;
}
}

amrex::launch(nblocks, nthreads, 0, stream,
[=] AMREX_GPU_DEVICE () noexcept
{
// Each block processes [ibegin,iend).
N ibegin = nelms_per_block * blockIdx.x;
N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
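// For multi-block runs blocksum_p has been overwritten with the exclusive scan of the
// block aggregates, i.e. the offset this block must add to its block-local results.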
T prev_sum = (blockIdx.x == 0) ? 0 : blocksum_p[blockIdx.x];
for (N offset = ibegin + threadIdx.x; offset < iend; offset += blockDim.x) {
fout(offset, prev_sum + blockresult_p[offset]);
}
});

if (a_ret_sum && nblocks == 1) {
Gpu::dtoh_memcpy_async(&totalsum, totalsum_p, sizeof(T));
}
Gpu::streamSynchronize();
The_Arena()->free(dp);

AMREX_GPU_ERROR_CHECK();

return totalsum;
}
#endif // #if defined(AMREX_GPU_MULTIPASS_SCAN)
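
Usage, for context (illustrative, not taken from this commit): the multi-pass kernel above is reached through the amrex::Scan::PrefixSum entry points whenever more than one block is needed. A minimal sketch, assuming caller-provided device pointers flags and offsets and an element count n:

// Exclusive scan of 0/1 flags into offsets; returns the total count.
// `flags`, `offsets`, and `n` are assumptions for this sketch, not names from the commit.
Long total = amrex::Scan::PrefixSum<Long>(n,
    [=] AMREX_GPU_DEVICE (Long i) -> Long { return flags[i]; },        // fin: read element i
    [=] AMREX_GPU_DEVICE (Long i, Long const& s) { offsets[i] = s; },  // fout: write scanned value
    amrex::Scan::Type::exclusive, amrex::Scan::retSum);

With AMREX_GPU_MULTIPASS_SCAN defined, such a call takes the PrefixSum_mp path on CUDA and HIP as soon as nblocks > 1.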

#if defined(AMREX_USE_HIP)

template <typename T, typename N, typename FIN, typename FOUT, typename TYPE,
typename M=std::enable_if_t<std::is_integral<N>::value &&
@@ -634,7 +806,15 @@ T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE, RetSum a_ret_sum = retSum)
constexpr int nthreads = nwarps_per_block*Gpu::Device::warp_size; // # of threads per block
constexpr int nelms_per_thread = sizeof(T) >= 8 ? 8 : 16;
constexpr int nelms_per_block = nthreads * nelms_per_thread;
-    int nblocks = (n + nelms_per_block - 1) / nelms_per_block;
+    AMREX_ALWAYS_ASSERT(static_cast<Long>(n) < static_cast<Long>(std::numeric_limits<int>::max())*nelms_per_block);
+    int nblocks = (Long(n) + nelms_per_block - 1) / nelms_per_block;

#if defined(AMREX_GPU_MULTIPASS_SCAN)
if (nblocks > 1) {
return PrefixSum_mp<0,T>(n, std::forward<FIN>(fin), std::forward<FOUT>(fout), TYPE{}, a_ret_sum);
}
#endif

std::size_t sm = 0;
auto stream = Gpu::gpuStream();

@@ -791,6 +971,12 @@ T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE, RetSum a_ret_sum = retSum)
ScanTileState tile_state;
tile_state.Init(nblocks, tile_state_p, tile_state_size); // Init ScanTileState on host

#if defined(AMREX_GPU_MULTIPASS_SCAN)
if (nblocks > 1) {
return PrefixSum_mp<0,T>(n, std::forward<FIN>(fin), std::forward<FOUT>(fout), TYPE{}, a_ret_sum);
}
#endif

if (nblocks > 1) {
// Init ScanTileState on device
amrex::launch((nblocks+nthreads-1)/nthreads, nthreads, 0, stream, [=] AMREX_GPU_DEVICE ()
@@ -912,6 +1098,13 @@ T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE, RetSum a_ret_sum = retSum)
constexpr int nelms_per_block = nthreads * nchunks;
AMREX_ALWAYS_ASSERT(static_cast<Long>(n) < static_cast<Long>(std::numeric_limits<int>::max())*nelms_per_block);
int nblocks = (static_cast<Long>(n) + nelms_per_block - 1) / nelms_per_block;

#if defined(AMREX_GPU_MULTIPASS_SCAN)
if (nblocks > 1) {
return PrefixSum_mp<0,T>(n, std::forward<FIN>(fin), std::forward<FOUT>(fout), TYPE{}, a_ret_sum);
}
#endif

std::size_t sm = sizeof(T) * (Gpu::Device::warp_size + nwarps_per_block) + sizeof(int);
auto stream = Gpu::gpuStream();

@@ -1141,7 +1334,9 @@ T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE, RetSum a_ret_sum = retSum)
return totalsum;
}

-#endif
+#endif // #if defined(AMREX_USE_HIP)

#endif // #if defined(AMREX_USE_SYCL)

// The return value is the total sum if a_ret_sum is true.
template <typename N, typename T, typename M=std::enable_if_t<std::is_integral<N>::value> >
