Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request #588 from senior-zero/fix-main/github/warp_scan_test
Browse files: browse the repository at this point in the history
Fix warp scan test
  • Loading branch information
gevtushenko authored Nov 1, 2022
2 parents e0fe848 + 2dd78ca commit 4b173be
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions test/test_warp_scan.cu
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
Expand Down Expand Up @@ -47,7 +47,7 @@ using namespace cub;
// Globals, constants and typedefs
//---------------------------------------------------------------------

static const int NUM_WARPS = 2;
static const int MAX_WARPS = 2;


bool g_verbose = false;
Expand Down Expand Up @@ -253,7 +253,7 @@ __global__ void WarpScanKernel(
typedef WarpScan<T, LOGICAL_WARP_THREADS> WarpScanT;

// Allocate temp storage in shared memory
__shared__ typename WarpScanT::TempStorage temp_storage[NUM_WARPS];
__shared__ typename WarpScanT::TempStorage temp_storage[MAX_WARPS];

// Get warp index
int warp_id = threadIdx.x / LOGICAL_WARP_THREADS;
Expand Down Expand Up @@ -318,9 +318,10 @@ void Initialize(
int logical_warp_items,
ScanOpT scan_op,
T initial_value,
T warp_aggregates[NUM_WARPS])
T *warp_aggregates,
int total_warps)
{
for (int w = 0; w < NUM_WARPS; ++w)
for (int w = 0; w < total_warps; ++w)
{
int base_idx = (w * logical_warp_items);
int i = base_idx;
Expand Down Expand Up @@ -358,9 +359,10 @@ void Initialize(
int logical_warp_items,
ScanOpT scan_op,
NullType,
T warp_aggregates[NUM_WARPS])
T *warp_aggregates,
int total_warps)
{
for (int w = 0; w < NUM_WARPS; ++w)
for (int w = 0; w < total_warps; ++w)
{
int base_idx = (w * logical_warp_items);
int i = base_idx;
Expand Down Expand Up @@ -399,7 +401,10 @@ void Test(
InitialValueT initial_value)
{
enum {
TOTAL_ITEMS = LOGICAL_WARP_THREADS * NUM_WARPS,
IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0)),
IS_POW_OF_TWO = ((LOGICAL_WARP_THREADS & (LOGICAL_WARP_THREADS - 1)) == 0),
TOTAL_WARPS = (IS_ARCH_WARP || IS_POW_OF_TWO) ? MAX_WARPS : 1,
TOTAL_ITEMS = LOGICAL_WARP_THREADS * TOTAL_WARPS,
};

// Allocate host arrays
Expand All @@ -408,7 +413,7 @@ void Test(
T *h_aggregate = new T[TOTAL_ITEMS];

// Initialize problem
T aggregates[NUM_WARPS];
T aggregates[MAX_WARPS];

Initialize(
gen_mode,
Expand All @@ -417,7 +422,8 @@ void Test(
LOGICAL_WARP_THREADS,
scan_op,
initial_value,
aggregates);
aggregates,
TOTAL_WARPS);

if (g_verbose)
{
Expand All @@ -426,7 +432,7 @@ void Test(
printf("\n");
}

for (int w = 0; w < NUM_WARPS; ++w)
for (int w = 0; w < TOTAL_WARPS; ++w)
{
for (int i = 0; i < LOGICAL_WARP_THREADS; ++i)
{
Expand Down Expand Up @@ -458,7 +464,7 @@ void Test(
fflush(stdout);

// Run aggregate/prefix kernel
WarpScanKernel<LOGICAL_WARP_THREADS, TEST_MODE><<<1, TOTAL_ITEMS>>>(
WarpScanKernel<LOGICAL_WARP_THREADS, TEST_MODE><<<1, LOGICAL_WARP_THREADS * TOTAL_WARPS>>>(
d_in,
d_out,
d_aggregate,
Expand Down Expand Up @@ -586,7 +592,6 @@ void Test(GenMode gen_mode)
// complex
Test<LOGICAL_WARP_THREADS>(gen_mode, Sum(), TestFoo::MakeTestFoo(17, 21, 32, 85));
Test<LOGICAL_WARP_THREADS>(gen_mode, Sum(), TestBar(17, 21));

}


Expand Down

0 comments on commit 4b173be

Please sign in to comment.