[CUDA] BF16 support (apache#7014)
hgt312 authored and Trevor Morris committed May 6, 2021
1 parent 20ab3e0 commit 125e101
Showing 8 changed files with 174 additions and 7 deletions.
9 changes: 8 additions & 1 deletion include/tvm/runtime/data_type.h
@@ -160,12 +160,19 @@ class DataType {
*/
static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); }
/*!
* \brief Construct an uint type.
* \brief Construct a float type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); }
/*!
* \brief Construct a bfloat type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); }
/*!
* \brief Construct a bool type.
* \param lanes The number of lanes
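
On the Python side, the new constructor lines up with the existing dtype-string convention; a minimal sketch (the mapping to DataType::BFloat is inferred from this commit, and the vector string form is the one the new test at the bottom uses):

import tvm
from tvm import te

# "bfloat16" and "bfloat16x4" are expected to correspond to
# DataType::BFloat(16, 1) and DataType::BFloat(16, 4) in C++.
A = te.placeholder((16,), name="A", dtype="bfloat16")
B = te.placeholder((16,), name="B", dtype="bfloat16x4")
print(A.dtype, B.dtype)  # bfloat16 bfloat16x4
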
16 changes: 15 additions & 1 deletion python/tvm/contrib/nvcc.py
@@ -302,8 +302,22 @@ def have_tensorcore(compute_version=None, target=None):
    major, minor = compute_version.split("_")[1]
    compute_version = major + "." + minor
    major, _ = parse_compute_version(compute_version)
    if major >= 7:
        return True

    return False


def have_bf16(compute_version):
    """Whether bf16 (bfloat16) support is provided in the given compute capability or not

    Parameters
    ----------
    compute_version : str
        compute capability of a GPU (e.g. "8.0")
    """
    major, _ = parse_compute_version(compute_version)
    if major >= 8:
        return True

    return False
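
A usage sketch of the new helper, mirroring how the test at the bottom of this commit gates itself (tvm.gpu(0).compute_version is the same attribute the test queries):

import tvm
from tvm.contrib.nvcc import have_bf16

# bf16 requires compute capability 8.0 (Ampere) or newer.
ctx = tvm.gpu(0)
if not have_bf16(ctx.compute_version):
    print("skip because gpu does not support bf16")
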
4 changes: 3 additions & 1 deletion python/tvm/runtime/ndarray.py
@@ -148,7 +148,9 @@ def copyfrom(self, source_array):
                    source_array.shape, shape
                )
            )
        source_array = np.ascontiguousarray(source_array, dtype=dtype)
        source_array = np.ascontiguousarray(
            source_array, dtype="uint16" if dtype == "bfloat16" else dtype
        )
        assert source_array.flags["C_CONTIGUOUS"]
        data = source_array.ctypes.data_as(ctypes.c_void_p)
        nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
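
Because NumPy has no native bfloat16 dtype, bf16 payloads cross the Python boundary as uint16 bit patterns, and copyfrom() now reinterprets them instead of casting. A minimal round-trip sketch in the same style as the test below:

import numpy as np
import tvm

ctx = tvm.gpu(0)
raw = np.zeros((64,), dtype="uint16")                   # bf16 bit patterns
a = tvm.nd.empty((64,), "bfloat16", ctx).copyfrom(raw)  # reinterpret, no cast
back = tvm.nd.empty((64,), "uint16", ctx).copyfrom(a)   # recover the raw bits
assert (back.asnumpy() == raw).all()
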
72 changes: 70 additions & 2 deletions src/target/source/codegen_cuda.cc
@@ -61,6 +61,18 @@ std::string CodeGenCUDA::Finish() {
decl_stream << _cuda_half_util;
}

if (enable_bf16_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n";
decl_stream << "#include <cuda_bf16.h>\n";
decl_stream << "__device__ nv_bfloat16 max"
<< "(nv_bfloat16 a, nv_bfloat16 b)\n"
<< "{\n return __hgt(a, b) ? a : b;\n}\n";
decl_stream << "__device__ nv_bfloat16 min(nv_bfloat16 a, nv_bfloat16 b)\n"
<< "{\n return __hlt(a, b) ? a : b;\n}\n";
decl_stream << "#endif\n\n";
decl_stream << _cuda_bfloat16_util;
}

if (enable_warp_shuffle_) {
decl_stream << _cuda_warp_intrinsic_util;
}
@@ -170,6 +182,17 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
os << lanes;
return;
}
} else if (t.is_bfloat16()) {
enable_bf16_ = true;
if (t.is_scalar()) {
os << "nv_bfloat16";
} else if (lanes <= 8) {
ICHECK_EQ(lanes % 2, 0) << "only support even lane for bfloat16 type";
os << "uint" << lanes / 2;
} else {
fail = true;
}
if (!fail) return;
} else if (t == DataType::Bool()) {
os << "bool";
return;
@@ -382,6 +405,8 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
}
} else if (t.is_float16()) {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else if (t.is_bfloat16()) {
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
@@ -427,6 +452,9 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
} else if (t.is_float16()) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
<< value << ";\n";
} else if (t.is_bfloat16()) {
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
<< " = " << value << ";\n";
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
@@ -687,7 +715,8 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) ||
op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1))
op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) ||
op->dtype == DataType::BFloat(16))
<< "Matrix_a and matrix_b only support half or char or unsigned char "
<< "or uint4 or int4 or int1 type for now";
} else {
@@ -767,6 +796,19 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
return;
}

if (op->dtype.is_bfloat16()) {
std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < op->lanes / 2; ++i) {
if (i != 0) os << ", ";
os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
}
os << ')';
return;
}

std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
@@ -836,6 +878,13 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) {
}

inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
// Type code is kBFloat
if (op->dtype.is_bfloat16()) {
os << "__float2bfloat16_rn";
os << '(' << std::scientific << op->value << 'f' << ')';
return;
}
// Type code is kFloat
switch (op->dtype.bits()) {
case 64:
case 32: {
@@ -938,7 +987,7 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const LoadNode*
// Cast away volatile qualifier for fp16 types. That is, only loads and
// stores are volatile. The loaded objects are not marked as volatile.
//
if (op->dtype.is_float16() && IsVolatile(op->buffer_var.get())) {
if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer_var.get())) {
os << "(";
PrintType(op->dtype, os);
os << ")(" << value << ")";
@@ -979,6 +1028,25 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
return;
}

if (t.is_bfloat16()) {
if (i == 0) {
os << "make_";
PrintType(t, os);
os << '(';
}
if (i % 2 == 0) {
os << "__pack_bfloat162(" << value;
} else {
os << "," << value << ")";
if (i != t.lanes() - 1) {
os << ",";
} else {
os << ")";
}
}
return;
}

if (i == 0) {
os << "make_";
PrintType(t, os);
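
To see what these codegen hooks actually emit, the CUDA source of a built module can be dumped with get_source(); a minimal sketch reusing the schedule and disabled passes from the test at the bottom of this commit:

import tvm
from tvm import te

n, lanes = 64, 4
A = te.placeholder((n,), name="A", dtype="bfloat16x%d" % lanes)
B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=8)
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[B].bind(xi, te.thread_axis("threadIdx.x"))
with tvm.transform.PassContext(
    disabled_pass=["tir.BF16Promote", "tir.BF16CastElimination", "tir.BF16TypeLowering"]
):
    fun = tvm.build(s, [A, B], "cuda")
# The device code should use nv_bfloat16 and the __pack_nv_bfloat162 helper.
print(fun.imported_modules[0].get_source())
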
4 changes: 3 additions & 1 deletion src/target/source/codegen_cuda.h
@@ -42,7 +42,7 @@ class CodeGenCUDA final : public CodeGenC {
void Init(bool output_ssa);
std::string Finish();
bool need_include_path() {
return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
return (enable_fp16_ || enable_bf16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
}
// override behavior
void PrintFuncPrefix() final;
@@ -88,6 +88,8 @@ class CodeGenCUDA final : public CodeGenC {
std::string vid_global_barrier_expect_;
// whether enable fp16
bool enable_fp16_{false};
// whether enable bf16
bool enable_bf16_{false};
// whether enable int8
bool enable_int8_{false};
// whether enable warp shuffle intrinsics
2 changes: 2 additions & 0 deletions src/target/source/intrin_rule_cuda.cc
@@ -43,6 +43,8 @@ struct CUDAMath {
default:
return "";
}
} else if (t.is_bfloat16()) {
return 'h' + name;
}
return "";
}
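
With this rule, a bfloat16 math call such as tanh lowers to the 'h'-prefixed helper (htanh) defined in cuda_half_t.h below, just like fp16. A rough Python rendition of the name dispatch, for illustration only; the float32/float64 branches are not shown in this hunk and are assumed to follow the usual CUDA convention:

def cuda_math_name(name, dtype):
    # Illustration of the CUDAMath rule; not a TVM API.
    if dtype in ("float16", "bfloat16"):
        return "h" + name    # e.g. tanh -> htanh, pow -> hpow
    if dtype == "float32":
        return name + "f"    # assumed: e.g. tanh -> tanhf
    if dtype == "float64":
        return name          # assumed: e.g. tanh -> tanh
    return ""

assert cuda_math_name("tanh", "bfloat16") == "htanh"
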
24 changes: 24 additions & 0 deletions src/target/source/literal/cuda_half_t.h
@@ -311,6 +311,30 @@ static inline __device__ __host__ half htanh(half x) {
#endif
)";

static constexpr const char* _cuda_bfloat16_util = R"(
// Pack two bfloat16 values.
static inline __device__ __host__ unsigned
__pack_nv_bfloat162(const nv_bfloat16 x, const nv_bfloat16 y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
// fix undefined bf16 math function
static inline __device__ __host__ nv_bfloat16 hpow(nv_bfloat16 x, nv_bfloat16 y) {
float tmp_x = __bfloat162float(x);
float tmp_y = __bfloat162float(y);
float result = powf(tmp_x, tmp_y);
return __float2bfloat16(result);
}
static inline __device__ __host__ nv_bfloat16 htanh(nv_bfloat16 x) {
float tmp_x = __bfloat162float(x);
float result = tanhf(tmp_x);
return __float2bfloat16(result);
}
)";

static constexpr const char* _cuda_warp_intrinsic_util = R"(
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)
#define __shfl_sync(mask, var, lane, width) \
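
For reference, __pack_nv_bfloat162 simply concatenates the two 16-bit payloads, with the first argument in the low half ((v1 << 16) | v0); a quick standalone Python check of that bit layout, using bf16 bit patterns chosen for this example:

def pack_bfloat162(lo_bits, hi_bits):
    # Same layout as __pack_nv_bfloat162: second value in the high 16 bits.
    return ((hi_bits & 0xFFFF) << 16) | (lo_bits & 0xFFFF)

one_bf16 = 0x3F80  # top 16 bits of the float32 pattern 0x3F800000 (1.0f)
two_bf16 = 0x4000  # top 16 bits of 0x40000000 (2.0f)
assert pack_bfloat162(one_bf16, two_bf16) == 0x40003F80
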
50 changes: 49 additions & 1 deletion tests/python/unittest/test_target_codegen_cuda.py
@@ -19,7 +19,7 @@
import numpy as np
from tvm import topi
import unittest
from tvm.contrib.nvcc import have_fp16, have_int8
from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
from tvm.contrib import nvcc
import tvm.testing

@@ -67,6 +67,53 @@ def check_cuda(dtype, n, lanes):
check_cuda("float16", 64, 8)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_bf16_vectorize_add():
if not have_bf16(tvm.gpu(0).compute_version):
print("skip because gpu does not support bf16")
return
num_thread = 8

def np_float2np_bf16(arr):
"""Convert a numpy array of float to a numpy array
of bf16 in uint16"""
orig = arr.view("<u4")
bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
return np.right_shift(orig + bias, 16).astype("uint16")

def np_bf162np_float(arr):
"""Convert a numpy array of bf16 (uint16) to a numpy array
of float"""
u32 = np.left_shift(arr.astype("uint32"), 16)
return u32.view("<f4")

def check_cuda(n, lanes):
A = te.placeholder((n,), name="A", dtype="bfloat16x%d" % lanes)
B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, bx)
s[B].bind(xi, tx)
with tvm.transform.PassContext(
disabled_pass=["tir.BF16Promote", "tir.BF16CastElimination", "tir.BF16TypeLowering"]
):
fun = tvm.build(s, [A, B], "cuda")
ctx = tvm.gpu(0)
np_a = np.random.uniform(size=(n, lanes)).astype("float32")
np_a = np_bf162np_float(np_float2np_bf16(np_a))
a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np_float2np_bf16(np_a))
c = tvm.nd.empty((n,), B.dtype, ctx)
fun(a, c)
c = tvm.nd.empty((n, lanes), "uint16", ctx).copyfrom(c)
tvm.testing.assert_allclose(c.asnumpy(), np_float2np_bf16(np_a + 1))

check_cuda(64, 2)
check_cuda(64, 4)
check_cuda(64, 6)
check_cuda(64, 8)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_multiply_add():
@@ -922,6 +969,7 @@ def test_unrolled_vectorization():

if __name__ == "__main__":
    test_cuda_vectorize_add()
    test_cuda_bf16_vectorize_add()
    test_cuda_multiply_add()
    test_cuda_vectorize_load()
    test_cuda_make_int8()
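
The np_float2np_bf16 helper in the new test implements float32-to-bf16 round-to-nearest-even by adding a bias before truncating the low 16 bits. A small standalone check of that rounding (the helper is copied verbatim from the test above):

import numpy as np

def np_float2np_bf16(arr):
    orig = arr.view("<u4")
    bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
    return np.right_shift(orig + bias, 16).astype("uint16")

x = np.array([1.0, 1.00390625, 3.14159265], dtype="float32")
print(np_float2np_bf16(x))
# 1.0        -> 0x3F80 (exactly representable in bf16)
# 1.00390625 -> halfway case (low bits 0x8000), rounds to the even mantissa -> 0x3F80
# 3.14159265 -> 0x4049 (low bits 0x0FDB are below halfway, round down)
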
