Add BFloat16 support for radix sort
zasdfgbnm committed May 17, 2021
1 parent a8910ac commit ce68c89
Showing 4 changed files with 299 additions and 6 deletions.
21 changes: 21 additions & 0 deletions cub/util_type.cuh
@@ -40,6 +40,9 @@
#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) && !__NVCOMPILER_CUDA__
#include <cuda_fp16.h>
#endif
#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !__NVCOMPILER_CUDA__
#include <cuda_bf16.h>
#endif

#include "util_macro.cuh"
#include "util_arch.cuh"
@@ -1079,6 +1082,21 @@ struct FpLimits<__half>
};
#endif

#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !__NVCOMPILER_CUDA__
template <>
struct FpLimits<__nv_bfloat16>
{
static __host__ __device__ __forceinline__ __nv_bfloat16 Max() {
unsigned short max_word = 0x7F7F;
return reinterpret_cast<__nv_bfloat16&>(max_word);
}

static __host__ __device__ __forceinline__ __nv_bfloat16 Lowest() {
unsigned short lowest_word = 0xFF7F;
return reinterpret_cast<__nv_bfloat16&>(lowest_word);
}
};
#endif
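These constants are the extreme finite bfloat16 values: 0x7F7F has sign 0, exponent 0xFE, and mantissa 0x7F, i.e. (1 + 127/128) * 2^127, roughly 3.39e38; 0xFF7F is its negation. A quick host-side check of that decoding, given as a sketch only (not part of this commit; assumes CUDA 11+ for cuda_bf16.h and compilation with nvcc):

#include <cstdio>
#include <cstring>
#include <cuda_bf16.h>

int main()
{
    // Bit-cast via memcpy to sidestep the aliasing concerns of reinterpret_cast.
    unsigned short max_word = 0x7F7F;            // largest finite bfloat16
    __nv_bfloat16 bf;
    std::memcpy(&bf, &max_word, sizeof(bf));
    std::printf("%g\n", __bfloat162float(bf));   // prints about 3.38953e+38
    return 0;
}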

/**
* Basic type traits (fp primitive specialization)
@@ -1146,6 +1164,9 @@ template <> struct NumericTraits<double> : BaseTraits<FLOATING_POIN
#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) && !__NVCOMPILER_CUDA__
template <> struct NumericTraits<__half> : BaseTraits<FLOATING_POINT, true, false, unsigned short, __half> {};
#endif
#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !__NVCOMPILER_CUDA__
template <> struct NumericTraits<__nv_bfloat16> : BaseTraits<FLOATING_POINT, true, false, unsigned short, __nv_bfloat16> {};
#endif
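Registering __nv_bfloat16 as FLOATING_POINT with unsigned short as its bit representation is what lets the radix sort order it correctly: CUB twiddles floating-point keys into a monotonically increasing unsigned ordering before ranking digits (flip every bit of negative keys, flip only the sign bit of non-negative keys). A 16-bit sketch of that mapping, shown for illustration only; the helper name is made up and this code is not part of the commit:

// Map sign-magnitude float bits to an unsigned word that sorts in numeric order.
unsigned short twiddle_in(unsigned short key)
{
    unsigned short mask = (key & 0x8000) ? 0xFFFF : 0x8000;
    return key ^ mask;
}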

template <> struct NumericTraits<bool> : BaseTraits<UNSIGNED_INTEGER, true, false, typename UnitWord<bool>::VolatileWord, bool> {};

239 changes: 239 additions & 0 deletions test/bfloat16.h
@@ -0,0 +1,239 @@
/******************************************************************************
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/

#pragma once

/**
* \file
* Utilities for interacting with the opaque CUDA __nv_bfloat16 type
*/

#include <stdint.h>
#include <cuda_bf16.h>
#include <iosfwd>

#include <cub/util_type.cuh>

#ifdef __GNUC__
// There's a ton of type-punning going on in this file.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif


/******************************************************************************
* bfloat16_t
******************************************************************************/

/**
* Host-based bfloat16 data type compatible and convertible with __nv_bfloat16
*/
struct bfloat16_t
{
uint16_t __x;

/// Constructor from __nv_bfloat16
__host__ __device__ __forceinline__
bfloat16_t(const __nv_bfloat16 &other)
{
__x = reinterpret_cast<const uint16_t&>(other);
}

/// Constructor from integer
__host__ __device__ __forceinline__
bfloat16_t(int a)
{
*this = bfloat16_t(float(a));
}

/// Default constructor
bfloat16_t() = default;

/// Constructor from float
__host__ __device__ __forceinline__
bfloat16_t(float a)
{
// Reference:
// https://github.com/pytorch/pytorch/blob/44cc873fba5e5ffc4d4d4eef3bd370b653ce1ce1/c10/util/BFloat16.h#L51
uint16_t ir;
if (a != a) {
ir = UINT16_C(0x7FFF);
} else {
union {
uint32_t U32;
float F32;
};

F32 = a;
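// Round to nearest, ties to even: the bias is 0x7FFF plus the lowest bit
// that survives truncation, so a tie rounds toward an even result once the
// low 16 bits are dropped.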
uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
ir = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
}
this->__x = ir;
}

/// Cast to __nv_bfloat16
__host__ __device__ __forceinline__
operator __nv_bfloat16() const
{
return reinterpret_cast<const __nv_bfloat16&>(__x);
}

/// Cast to float
__host__ __device__ __forceinline__
operator float() const
{
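// bfloat16 is float32 with the low 16 mantissa bits dropped, so widening
// back to float is exact: place the 16 bits in the high half of the float's
// bit pattern and leave the low half zero.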
float f = 0;
uint32_t *p = reinterpret_cast<uint32_t *>(&f);
*p = uint32_t(__x) << 16;
return f;
}


/// Get raw storage
__host__ __device__ __forceinline__
uint16_t raw()
{
return this->__x;
}

/// Equality
__host__ __device__ __forceinline__
bool operator ==(const bfloat16_t &other)
{
return (this->__x == other.__x);
}

/// Inequality
__host__ __device__ __forceinline__
bool operator !=(const bfloat16_t &other)
{
return (this->__x != other.__x);
}

/// Assignment by sum
__host__ __device__ __forceinline__
bfloat16_t& operator +=(const bfloat16_t &rhs)
{
*this = bfloat16_t(float(*this) + float(rhs));
return *this;
}

/// Multiply
__host__ __device__ __forceinline__
bfloat16_t operator*(const bfloat16_t &other)
{
return bfloat16_t(float(*this) * float(other));
}

/// Add
__host__ __device__ __forceinline__
bfloat16_t operator+(const bfloat16_t &other)
{
return bfloat16_t(float(*this) + float(other));
}

/// Less-than
__host__ __device__ __forceinline__
bool operator<(const bfloat16_t &other) const
{
return float(*this) < float(other);
}

/// Less-than-equal
__host__ __device__ __forceinline__
bool operator<=(const bfloat16_t &other) const
{
return float(*this) <= float(other);
}

/// Greater-than
__host__ __device__ __forceinline__
bool operator>(const bfloat16_t &other) const
{
return float(*this) > float(other);
}

/// Greater-than-equal
__host__ __device__ __forceinline__
bool operator>=(const bfloat16_t &other) const
{
return float(*this) >= float(other);
}

/// numeric_traits<bfloat16_t>::max
__host__ __device__ __forceinline__
static bfloat16_t max() {
uint16_t max_word = 0x7F7F;
return reinterpret_cast<bfloat16_t&>(max_word);
}

/// numeric_traits<bfloat16_t>::lowest
__host__ __device__ __forceinline__
static bfloat16_t lowest() {
uint16_t lowest_word = 0xFF7F;
return reinterpret_cast<bfloat16_t&>(lowest_word);
}
};


/******************************************************************************
* I/O stream overloads
******************************************************************************/

/// Insert formatted \p bfloat16_t into the output stream
std::ostream& operator<<(std::ostream &out, const bfloat16_t &x)
{
out << (float)x;
return out;
}


/// Insert formatted \p __nv_bfloat16 into the output stream
std::ostream& operator<<(std::ostream &out, const __nv_bfloat16 &x)
{
return out << bfloat16_t(x);
}


/******************************************************************************
* Traits overloads
******************************************************************************/

template <>
struct cub::FpLimits<bfloat16_t>
{
static __host__ __device__ __forceinline__ bfloat16_t Max() { return bfloat16_t::max(); }

static __host__ __device__ __forceinline__ bfloat16_t Lowest() { return bfloat16_t::lowest(); }
};

template <> struct cub::NumericTraits<bfloat16_t> : cub::BaseTraits<FLOATING_POINT, true, false, unsigned short, bfloat16_t> {};


#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
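A minimal host-side round trip through this helper type, given as a sketch of how the test harness is expected to use it (not part of the commit; the file name is illustrative, and building it assumes nvcc with CUDA 11+ and the CUB include path):

// bf16_roundtrip.cu (illustrative name)
#include <cstdio>
#include "bfloat16.h"

int main()
{
    bfloat16_t x(1.0f / 3.0f);     // rounds to the nearest bfloat16
    float back = float(x);         // widening back to float is exact
    std::printf("raw=0x%04x back=%g\n", x.raw(), back);
    std::printf("max=%g lowest=%g\n",
                float(bfloat16_t::max()), float(bfloat16_t::lowest()));
    return 0;
}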
36 changes: 30 additions & 6 deletions test/test_device_radix_sort.cu
@@ -41,6 +41,10 @@
#include <cuda_fp16.h>
#endif

#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !__NVCOMPILER_CUDA__
#include <cuda_bf16.h>
#endif

#include <cub/util_allocator.cuh>
#include <cub/util_math.cuh>
#include <cub/device/device_radix_sort.cuh>
@@ -682,6 +686,24 @@ void InitializeSolution(
// Test generation
//---------------------------------------------------------------------

template <typename T>
struct UnwrapHalfAndBfloat16 {
using Type = T;
};

#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) && !__NVCOMPILER_CUDA__
template <>
struct UnwrapHalfAndBfloat16<half_t> {
using Type = __half;
};
#endif

#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !__NVCOMPILER_CUDA__
template <>
struct UnwrapHalfAndBfloat16<bfloat16_t> {
using Type = __nv_bfloat16;
};
#endif
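In effect this maps the test-side wrapper types onto the native CUDA types that the device sort sees, and leaves every other key type untouched. A compile-time sketch of how the mapping resolves (assumes <type_traits> and that both version guards above are active):

#include <type_traits>

static_assert(std::is_same<UnwrapHalfAndBfloat16<int>::Type, int>::value, "");
static_assert(std::is_same<UnwrapHalfAndBfloat16<half_t>::Type, __half>::value, "");
static_assert(std::is_same<UnwrapHalfAndBfloat16<bfloat16_t>::Type, __nv_bfloat16>::value, "");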

/**
* Test DeviceRadixSort
@@ -703,11 +725,7 @@ void Test(
ValueT *h_reference_values)
{
// Key alias type
#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) && !__NVCOMPILER_CUDA__
typedef typename If<Equals<KeyT, half_t>::VALUE, __half, KeyT>::Type KeyAliasT;
#else
typedef KeyT KeyAliasT;
#endif
using KeyAliasT = typename UnwrapHalfAndBfloat16<KeyT>::Type;

const bool KEYS_ONLY = Equals<ValueT, NullType>::VALUE;

@@ -1222,6 +1240,9 @@ int main(int argc, char** argv)

#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) && !__NVCOMPILER_CUDA__
Test<CUB, half_t, NullType, IS_DESCENDING>(num_items, 1, RANDOM, entropy_reduction, 0, bits);
#endif
#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !__NVCOMPILER_CUDA__
Test<CUB, bfloat16_t, NullType, IS_DESCENDING>(num_items, 1, RANDOM, entropy_reduction, 0, bits);
#endif
Test<CUB, float, NullType, IS_DESCENDING>(num_items, 1, RANDOM, entropy_reduction, 0, bits);
Test<CUB, double, NullType, IS_DESCENDING>(num_items, 1, RANDOM, entropy_reduction, 0, bits);
@@ -1280,7 +1301,10 @@ int main(int argc, char** argv)
TestGen<unsigned long long> (num_items, num_segments);

#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000) && !__NVCOMPILER_CUDA__
TestGen<half_t> (num_items, num_segments);
TestGen<half_t> (num_items, num_segments);
#endif
#if (__CUDACC_VER_MAJOR__ >= 11 || CUDA_VERSION >= 11000) && !__NVCOMPILER_CUDA__
TestGen<bfloat16_t> (num_items, num_segments);
#endif
TestGen<float> (num_items, num_segments);

9 changes: 9 additions & 0 deletions test/test_util.h
@@ -49,6 +49,7 @@

#include "mersenne.h"
#include "half.h"
#include "bfloat16.h"

#include "cub/util_debug.cuh"
#include "cub/util_device.cuh"
@@ -406,7 +407,15 @@ __noinline__ bool IsNaN<half_t>(half_t val)
((bits >= 0xFC01) /*&& (bits <= 0xFFFFFFFF)*/));
}

template<>
__noinline__ bool IsNaN<bfloat16_t>(bfloat16_t val)
{
const auto bits = SafeBitCast<unsigned short>(val);

// the commented-out upper bound is always true; kept for documentation:
return (((bits >= 0x7F81) && (bits <= 0x7FFF)) ||
((bits >= 0xFF81) /*&& (bits <= 0xFFFFFFFF)*/));
}
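The ranges encode the usual IEEE rule: a bfloat16 word is NaN exactly when its 8 exponent bits are all ones and its 7 mantissa bits are not all zero, regardless of sign. An equivalent mask-based form, given as a sketch only (the helper name is illustrative, not part of the diff):

inline bool is_bf16_nan(unsigned short bits)
{
    // exponent mask 0x7F80, mantissa mask 0x007F
    return ((bits & 0x7F80) == 0x7F80) && ((bits & 0x007F) != 0);
}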

/**
* Generates random keys.
