diff --git a/cub/util_type.cuh b/cub/util_type.cuh
index 4d6155963f..c38051078e 100644
--- a/cub/util_type.cuh
+++ b/cub/util_type.cuh
@@ -40,6 +40,9 @@
 #if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) && !__NVCOMPILER_CUDA__
     #include <cuda_fp16.h>
 #endif
+#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !__NVCOMPILER_CUDA__
+    #include <cuda_bf16.h>
+#endif
 
 #include "util_macro.cuh"
 #include "util_arch.cuh"
@@ -1079,6 +1082,21 @@ struct FpLimits<__half>
 };
 #endif
 
+#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !__NVCOMPILER_CUDA__
+template <>
+struct FpLimits<__nv_bfloat16>
+{
+    static __host__ __device__ __forceinline__ __nv_bfloat16 Max() {
+        unsigned short max_word = 0x7F7F;
+        return reinterpret_cast<__nv_bfloat16&>(max_word);
+    }
+
+    static __host__ __device__ __forceinline__ __nv_bfloat16 Lowest() {
+        unsigned short lowest_word = 0xFF7F;
+        return reinterpret_cast<__nv_bfloat16&>(lowest_word);
+    }
+};
+#endif
 
 /**
  * Basic type traits (fp primitive specialization)
@@ -1146,6 +1164,9 @@ template <> struct NumericTraits<double> : BaseTraits<FLOATING_POINT, true, false, unsigned long long, double> {};
 #if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) && !__NVCOMPILER_CUDA__
     template <> struct NumericTraits<__half> : BaseTraits<FLOATING_POINT, true, false, unsigned short, __half> {};
 #endif
+#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !__NVCOMPILER_CUDA__
+    template <> struct NumericTraits<__nv_bfloat16> : BaseTraits<FLOATING_POINT, true, false, unsigned short, __nv_bfloat16> {};
+#endif
 
 template <> struct NumericTraits<bool> : BaseTraits<UNSIGNED_INTEGER, true, false, typename UnitWord<bool>::VolatileWord, bool> {};
 
diff --git a/test/bfloat16.h b/test/bfloat16.h
new file mode 100644
index 0000000000..4413f9145d
--- /dev/null
+++ b/test/bfloat16.h
@@ -0,0 +1,239 @@
+/******************************************************************************
+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright
+ *       notice, this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright
+ *       notice, this list of conditions and the following disclaimer in the
+ *       documentation and/or other materials provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the
+ *       names of its contributors may be used to endorse or promote products
+ *       derived from this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ ******************************************************************************/
+
+#pragma once
+
+/**
+ * \file
+ * Utilities for interacting with the opaque CUDA __nv_bfloat16 type
+ */
+
+#include <stdint.h>
+#include <cstdlib>
+#include <iosfwd>
+
+#include <cuda_bf16.h>
+
+#ifdef __GNUC__
+// There's a ton of type-punning going on in this file.
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wstrict-aliasing"
+#endif
+
+
+/******************************************************************************
+ * bfloat16_t
+ ******************************************************************************/
+
+/**
+ * Host-based bfloat16 data type compatible and convertible with __nv_bfloat16
+ */
+struct bfloat16_t
+{
+    uint16_t __x;
+
+    /// Constructor from __nv_bfloat16
+    __host__ __device__ __forceinline__
+    bfloat16_t(const __nv_bfloat16 &other)
+    {
+        __x = reinterpret_cast<const uint16_t&>(other);
+    }
+
+    /// Constructor from integer
+    __host__ __device__ __forceinline__
+    bfloat16_t(int a)
+    {
+        *this = bfloat16_t(float(a));
+    }
+
+    /// Default constructor
+    bfloat16_t() = default;
+
+    /// Constructor from float
+    __host__ __device__ __forceinline__
+    bfloat16_t(float a)
+    {
+        // Reference:
+        // https://github.com/pytorch/pytorch/blob/44cc873fba5e5ffc4d4d4eef3bd370b653ce1ce1/c10/util/BFloat16.h#L51
+        uint16_t ir;
+        if (a != a) {
+            ir = UINT16_C(0x7FFF);
+        } else {
+            union {
+                uint32_t U32;
+                float    F32;
+            };
+
+            F32 = a;
+            uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
+            ir = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
+        }
+        this->__x = ir;
+    }
+
+    /// Cast to __nv_bfloat16
+    __host__ __device__ __forceinline__
+    operator __nv_bfloat16() const
+    {
+        return reinterpret_cast<const __nv_bfloat16&>(__x);
+    }
+
+    /// Cast to float
+    __host__ __device__ __forceinline__
+    operator float() const
+    {
+        float f = 0;
+        uint32_t *p = reinterpret_cast<uint32_t *>(&f);
+        *p = uint32_t(__x) << 16;
+        return f;
+    }
+
+
+    /// Get raw storage
+    __host__ __device__ __forceinline__
+    uint16_t raw()
+    {
+        return this->__x;
+    }
+
+    /// Equality
+    __host__ __device__ __forceinline__
+    bool operator ==(const bfloat16_t &other)
+    {
+        return (this->__x == other.__x);
+    }
+
+    /// Inequality
+    __host__ __device__ __forceinline__
+    bool operator !=(const bfloat16_t &other)
+    {
+        return (this->__x != other.__x);
+    }
+
+    /// Assignment by sum
+    __host__ __device__ __forceinline__
+    bfloat16_t& operator +=(const bfloat16_t &rhs)
+    {
+        *this = bfloat16_t(float(*this) + float(rhs));
+        return *this;
+    }
+
+    /// Multiply
+    __host__ __device__ __forceinline__
+    bfloat16_t operator*(const bfloat16_t &other)
+    {
+        return bfloat16_t(float(*this) * float(other));
+    }
+
+    /// Add
+    __host__ __device__ __forceinline__
+    bfloat16_t operator+(const bfloat16_t &other)
+    {
+        return bfloat16_t(float(*this) + float(other));
+    }
+
+    /// Less-than
+    __host__ __device__ __forceinline__
+    bool operator<(const bfloat16_t &other) const
+    {
+        return float(*this) < float(other);
+    }
+
+    /// Less-than-equal
+    __host__ __device__ __forceinline__
+    bool operator<=(const bfloat16_t &other) const
+    {
+        return float(*this) <= float(other);
+    }
+
+    /// Greater-than
+    __host__ __device__ __forceinline__
+    bool operator>(const bfloat16_t &other) const
+    {
+        return float(*this) > float(other);
+    }
+
+    /// Greater-than-equal
+    __host__ __device__ __forceinline__
+    bool operator>=(const bfloat16_t &other) const
+    {
+        return float(*this) >= float(other);
+    }
+
+    /// numeric_traits<bfloat16_t>::max
+    __host__ __device__ __forceinline__
+    static bfloat16_t max() {
+        uint16_t max_word = 0x7F7F;
+        return reinterpret_cast<bfloat16_t&>(max_word);
+    }
+
+    /// numeric_traits<bfloat16_t>::lowest
+    __host__ __device__ __forceinline__
+    static bfloat16_t lowest() {
+        uint16_t lowest_word = 0xFF7F;
+        return reinterpret_cast<bfloat16_t&>(lowest_word);
+    }
+};
+
+
+/******************************************************************************
+ * I/O stream overloads
+ ******************************************************************************/
+
+/// Insert formatted \p bfloat16_t into the output stream
+std::ostream& operator<<(std::ostream &out, const bfloat16_t &x)
+{
+    out << (float)x;
+    return out;
+}
+
+
+/// Insert formatted \p __nv_bfloat16 into the output stream
+std::ostream& operator<<(std::ostream &out, const __nv_bfloat16 &x)
+{
+    return out << bfloat16_t(x);
+}
+
+
+/******************************************************************************
+ * Traits overloads
+ ******************************************************************************/
+
+template <>
+struct cub::FpLimits<bfloat16_t>
+{
+    static __host__ __device__ __forceinline__ bfloat16_t Max() { return bfloat16_t::max(); }
+
+    static __host__ __device__ __forceinline__ bfloat16_t Lowest() { return bfloat16_t::lowest(); }
+};
+
+template <> struct cub::NumericTraits<bfloat16_t> : cub::BaseTraits<cub::FLOATING_POINT, true, false, unsigned short, bfloat16_t> {};
+
+
+#ifdef __GNUC__
+#pragma GCC diagnostic pop
+#endif
diff --git a/test/test_device_radix_sort.cu b/test/test_device_radix_sort.cu
index 1bc5b13581..c7652383b8 100644
--- a/test/test_device_radix_sort.cu
+++ b/test/test_device_radix_sort.cu
@@ -41,6 +41,10 @@
     #include <cuda_fp16.h>
 #endif
 
+#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !__NVCOMPILER_CUDA__
+    #include <cuda_bf16.h>
+#endif
+
 #include <cub/util_allocator.cuh>
 #include <cub/device/device_radix_sort.cuh>
 #include <cub/device/device_segmented_radix_sort.cuh>
@@ -682,6 +686,24 @@ void InitializeSolution(
 // Test generation
 //---------------------------------------------------------------------
 
+template <typename T>
+struct UnwrapHalfAndBfloat16 {
+  using Type = T;
+};
+
+#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) && !__NVCOMPILER_CUDA__
+template <>
+struct UnwrapHalfAndBfloat16<half_t> {
+  using Type = __half;
+};
+#endif
+
+#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !__NVCOMPILER_CUDA__
+template <>
+struct UnwrapHalfAndBfloat16<bfloat16_t> {
+  using Type = __nv_bfloat16;
+};
+#endif
 
 /**
  * Test DeviceRadixSort
@@ -703,11 +725,7 @@ void Test(
     ValueT  *h_reference_values)
 {
     // Key alias type
-#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) && !__NVCOMPILER_CUDA__
-    typedef typename If<Equals<KeyT, half_t>::VALUE, __half, KeyT>::Type KeyAliasT;
-#else
-    typedef KeyT KeyAliasT;
-#endif
+    using KeyAliasT = typename UnwrapHalfAndBfloat16<KeyT>::Type;
 
     const bool KEYS_ONLY = Equals<ValueT, NullType>::VALUE;
 
@@ -1222,6 +1240,9 @@ int main(int argc, char** argv)
 
 #if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) && !__NVCOMPILER_CUDA__
             Test<half_t, NullType, IS_DESCENDING>(num_items, 1, RANDOM, entropy_reduction, 0, bits);
+#endif
+#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !__NVCOMPILER_CUDA__
+            Test<bfloat16_t, NullType, IS_DESCENDING>(num_items, 1, RANDOM, entropy_reduction, 0, bits);
 #endif
             Test<float, NullType, IS_DESCENDING>(num_items, 1, RANDOM, entropy_reduction, 0, bits);
             Test<double, NullType, IS_DESCENDING>(num_items, 1, RANDOM, entropy_reduction, 0, bits);
@@ -1280,7 +1301,10 @@ int main(int argc, char** argv)
         TestGen<char, char>            (num_items, num_segments);
 
 #if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) && !__NVCOMPILER_CUDA__
-        TestGen<half_t, half_t>        (num_items, num_segments);
+        TestGen<half_t, half_t>        (num_items, num_segments);
+#endif
+#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !__NVCOMPILER_CUDA__
+        TestGen<bfloat16_t, bfloat16_t>(num_items, num_segments);
 #endif
         TestGen<float, float>          (num_items, num_segments);
 
diff --git a/test/test_util.h b/test/test_util.h
index 98d10e786f..d3b871f2fa 100644
--- a/test/test_util.h
+++ b/test/test_util.h
@@ -49,6 +49,7 @@
 
 #include "mersenne.h"
 #include "half.h"
+#include "bfloat16.h"
 
 #include "cub/util_debug.cuh"
 #include "cub/util_device.cuh"
@@ -406,7 +407,15 @@ __noinline__ bool IsNaN(half_t val)
                 ((bits >= 0xFC01) /*&& (bits <= 0xFFFFFFFF)*/));
 }
 
+template<>
+__noinline__ bool IsNaN(bfloat16_t val)
+{
+    const auto bits = SafeBitCast<uint16_t>(val);
+    // commented bit is always true, leaving for documentation:
+    return (((bits >= 0x7F81) && (bits <= 0x7FFF)) ||
+            ((bits >= 0xFF81) /*&& (bits <= 0xFFFFFFFF)*/));
+}
 
 /**
  * Generates random keys.