Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Optimize decouled look-back
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Dec 22, 2022
1 parent 73f3434 commit d94cc04
Show file tree
Hide file tree
Showing 4 changed files with 553 additions and 44 deletions.
1 change: 1 addition & 0 deletions .clangd
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ CompileFlags:
# report all errors
- "-ferror-limit=0"
- "-ftemplate-backtrace-limit=0"
- "-stdlib=libc++"
Remove:
- -stdpar
# strip CUDA fatbin args
Expand Down
135 changes: 91 additions & 44 deletions cub/agent/single_pass_scan_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,16 @@
#include <iterator>

#include <cub/config.cuh>
#include <cub/detail/strong_load.cuh>
#include <cub/detail/strong_store.cuh>
#include <cub/detail/uninitialized_copy.cuh>
#include <cub/thread/thread_load.cuh>
#include <cub/thread/thread_store.cuh>
#include <cub/util_device.cuh>
#include <cub/warp/warp_reduce.cuh>

#include <nv/target>

CUB_NAMESPACE_BEGIN


Expand Down Expand Up @@ -106,6 +110,44 @@ enum ScanTileStatus
SCAN_TILE_INCLUSIVE, // Inclusive tile prefix is available
};

namespace detail
{

template <int Delay, unsigned int GridThreshold = 500>
__device__ __forceinline__ void delay()
{
NV_IF_TARGET(NV_PROVIDES_SM_70,
(if (Delay > 0)
{
if (gridDim.x < GridThreshold)
{
__threadfence_block();
}
else
{
__nanosleep(Delay);
}
}));
}

template <int Delay = 350, unsigned int GridThreshold = 500>
__device__ __forceinline__ void delay_or_prevent_hoisting()
{
NV_IF_TARGET(NV_PROVIDES_SM_70,
(delay<Delay, GridThreshold>();),
(__threadfence_block();));
}

template <int Delay = 350, unsigned int GridThreshold = 500>
__device__ __forceinline__ void delay_on_dc_gpu_or_prevent_hoisting()
{
NV_DISPATCH_TARGET(
NV_IS_EXACTLY_SM_80, (delay<Delay, GridThreshold>();),
NV_PROVIDES_SM_70, (delay< 0, GridThreshold>();),
NV_IS_DEVICE, (__threadfence_block();));
}

}

/**
* Tile status interface.
Expand All @@ -127,20 +169,20 @@ struct ScanTileState<T, true>
// Status word type
using StatusWord = cub::detail::conditional_t<
sizeof(T) == 8,
long long,
unsigned long long,
cub::detail::conditional_t<
sizeof(T) == 4,
int,
cub::detail::conditional_t<sizeof(T) == 2, short, char>>>;
unsigned int,
cub::detail::conditional_t<sizeof(T) == 2, unsigned short, unsigned char>>>;

// Unit word type
using TxnWord = cub::detail::conditional_t<
sizeof(T) == 8,
longlong2,
ulonglong2,
cub::detail::conditional_t<
sizeof(T) == 4,
int2,
cub::detail::conditional_t<sizeof(T) == 2, int, uchar2>>>;
uint2,
unsigned int>>;

// Device word type
struct TileDescriptor
Expand Down Expand Up @@ -230,7 +272,8 @@ struct ScanTileState<T, true>

TxnWord alias;
*reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);

detail::store_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
}


Expand All @@ -245,7 +288,8 @@ struct ScanTileState<T, true>

TxnWord alias;
*reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);

detail::store_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
}

/**
Expand All @@ -257,13 +301,18 @@ struct ScanTileState<T, true>
T &value)
{
TileDescriptor tile_descriptor;
do

{
__threadfence_block(); // prevent hoisting loads from loop
TxnWord alias = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
TxnWord alias = detail::load_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
}

} while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));
while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff))
{
detail::delay_or_prevent_hoisting();
TxnWord alias = detail::load_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
}

status = tile_descriptor.status;
value = tile_descriptor.value;
Expand All @@ -281,7 +330,7 @@ template <typename T>
struct ScanTileState<T, false>
{
// Status word type
typedef char StatusWord;
using StatusWord = unsigned int;

// Constants
enum
Expand Down Expand Up @@ -382,12 +431,7 @@ struct ScanTileState<T, false>
{
// Update tile inclusive value
ThreadStore<STORE_CG>(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx, tile_inclusive);

// Fence
__threadfence();

// Update tile status
ThreadStore<STORE_CG>(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_INCLUSIVE));
detail::store_release(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_INCLUSIVE));
}


Expand All @@ -398,12 +442,7 @@ struct ScanTileState<T, false>
{
// Update tile partial value
ThreadStore<STORE_CG>(d_tile_partial + TILE_STATUS_PADDING + tile_idx, tile_partial);

// Fence
__threadfence();

// Update tile status
ThreadStore<STORE_CG>(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_PARTIAL));
detail::store_release(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_PARTIAL));
}

/**
Expand All @@ -414,17 +453,21 @@ struct ScanTileState<T, false>
StatusWord &status,
T &value)
{
do {
status = ThreadLoad<LOAD_CG>(d_tile_status + TILE_STATUS_PADDING + tile_idx);

__threadfence(); // prevent hoisting loads from loop or loads below above this one
do
{
status = detail::load_relaxed(d_tile_status + TILE_STATUS_PADDING + tile_idx);
__threadfence();

} while (status == SCAN_TILE_INVALID);
} while (WARP_ANY((status == SCAN_TILE_INVALID), 0xffffffff));

if (status == StatusWord(SCAN_TILE_PARTIAL))
value = ThreadLoad<LOAD_CG>(d_tile_partial + TILE_STATUS_PADDING + tile_idx);
{
value = ThreadLoad<LOAD_CG>(d_tile_partial + TILE_STATUS_PADDING + tile_idx);
}
else
value = ThreadLoad<LOAD_CG>(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx);
{
value = ThreadLoad<LOAD_CG>(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx);
}
}
};

Expand Down Expand Up @@ -471,7 +514,7 @@ template <
typename KeyT>
struct ReduceByKeyScanTileState<ValueT, KeyT, true>
{
typedef KeyValuePair<KeyT, ValueT>KeyValuePairT;
using KeyValuePairT = KeyValuePair<KeyT, ValueT>;

// Constants
enum
Expand All @@ -486,17 +529,17 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>
// Status word type
using StatusWord = cub::detail::conditional_t<
STATUS_WORD_SIZE == 8,
long long,
unsigned long long,
cub::detail::conditional_t<
STATUS_WORD_SIZE == 4,
int,
cub::detail::conditional_t<STATUS_WORD_SIZE == 2, short, char>>>;
unsigned int,
cub::detail::conditional_t<STATUS_WORD_SIZE == 2, unsigned short, unsigned char>>>;

// Status word type
using TxnWord = cub::detail::conditional_t<
TXN_WORD_SIZE == 16,
longlong2,
cub::detail::conditional_t<TXN_WORD_SIZE == 8, long long, int>>;
ulonglong2,
cub::detail::conditional_t<TXN_WORD_SIZE == 8, unsigned long long, unsigned int>>;

// Device word type (for when sizeof(ValueT) == sizeof(KeyT))
struct TileDescriptorBigStatus
Expand Down Expand Up @@ -594,7 +637,8 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>

TxnWord alias;
*reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);

detail::store_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
}


Expand All @@ -610,7 +654,8 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>

TxnWord alias;
*reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);

detail::store_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
}

/**
Expand All @@ -637,11 +682,12 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>
// value.key = tile_descriptor.key;

TileDescriptor tile_descriptor;

do
{
__threadfence_block(); // prevent hoisting loads from loop
TxnWord alias = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
detail::delay_on_dc_gpu_or_prevent_hoisting();
TxnWord alias = detail::load_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);

} while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));

Expand Down Expand Up @@ -750,6 +796,7 @@ struct TilePrefixCallbackOp
T window_aggregate;

// Wait for the warp-wide window of predecessor tiles to become valid
detail::delay<450>();
ProcessWindow(predecessor_idx, predecessor_status, window_aggregate);

// The exclusive tile prefix starts out as the current window aggregate
Expand Down
Loading

0 comments on commit d94cc04

Please sign in to comment.