diff --git a/test/test_warp_scan.cu b/test/test_warp_scan.cu index 6132b5159d..a755422c32 100644 --- a/test/test_warp_scan.cu +++ b/test/test_warp_scan.cu @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011, Duane Merrill. All rights reserved. - * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: @@ -47,7 +47,7 @@ using namespace cub; // Globals, constants and typedefs //--------------------------------------------------------------------- -static const int NUM_WARPS = 2; +static const int MAX_WARPS = 2; bool g_verbose = false; @@ -253,7 +253,7 @@ __global__ void WarpScanKernel( typedef WarpScan WarpScanT; // Allocate temp storage in shared memory - __shared__ typename WarpScanT::TempStorage temp_storage[NUM_WARPS]; + __shared__ typename WarpScanT::TempStorage temp_storage[MAX_WARPS]; // Get warp index int warp_id = threadIdx.x / LOGICAL_WARP_THREADS; @@ -318,9 +318,10 @@ void Initialize( int logical_warp_items, ScanOpT scan_op, T initial_value, - T warp_aggregates[NUM_WARPS]) + T *warp_aggregates, + int total_warps) { - for (int w = 0; w < NUM_WARPS; ++w) + for (int w = 0; w < total_warps; ++w) { int base_idx = (w * logical_warp_items); int i = base_idx; @@ -358,9 +359,10 @@ void Initialize( int logical_warp_items, ScanOpT scan_op, NullType, - T warp_aggregates[NUM_WARPS]) + T *warp_aggregates, + int total_warps) { - for (int w = 0; w < NUM_WARPS; ++w) + for (int w = 0; w < total_warps; ++w) { int base_idx = (w * logical_warp_items); int i = base_idx; @@ -399,7 +401,10 @@ void Test( InitialValueT initial_value) { enum { - TOTAL_ITEMS = LOGICAL_WARP_THREADS * NUM_WARPS, + IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0)), + IS_POW_OF_TWO = ((LOGICAL_WARP_THREADS & (LOGICAL_WARP_THREADS - 1)) == 0), + TOTAL_WARPS = (IS_ARCH_WARP || IS_POW_OF_TWO) ? MAX_WARPS : 1, + TOTAL_ITEMS = LOGICAL_WARP_THREADS * TOTAL_WARPS, }; // Allocate host arrays @@ -408,7 +413,7 @@ void Test( T *h_aggregate = new T[TOTAL_ITEMS]; // Initialize problem - T aggregates[NUM_WARPS]; + T aggregates[MAX_WARPS]; Initialize( gen_mode, @@ -417,7 +422,8 @@ void Test( LOGICAL_WARP_THREADS, scan_op, initial_value, - aggregates); + aggregates, + TOTAL_WARPS); if (g_verbose) { @@ -426,7 +432,7 @@ void Test( printf("\n"); } - for (int w = 0; w < NUM_WARPS; ++w) + for (int w = 0; w < TOTAL_WARPS; ++w) { for (int i = 0; i < LOGICAL_WARP_THREADS; ++i) { @@ -458,7 +464,7 @@ void Test( fflush(stdout); // Run aggregate/prefix kernel - WarpScanKernel<<<1, TOTAL_ITEMS>>>( + WarpScanKernel<<<1, LOGICAL_WARP_THREADS * TOTAL_WARPS>>>( d_in, d_out, d_aggregate, @@ -586,7 +592,6 @@ void Test(GenMode gen_mode) // complex Test(gen_mode, Sum(), TestFoo::MakeTestFoo(17, 21, 32, 85)); Test(gen_mode, Sum(), TestBar(17, 21)); - }