diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
index 6b69f97a66fe5..a4f9cf8a3c480 100644
--- a/src/codegen/llvm/codegen_llvm.cc
+++ b/src/codegen/llvm/codegen_llvm.cc
@@ -4,6 +4,7 @@
  */
 #ifdef TVM_LLVM_VERSION
 // Part of the code are adapted from Halide's CodeGen_LLVM
+#include
 #include
 #include
@@ -410,12 +411,16 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
 llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
   int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
   if (extent == num_elems && begin == 0) return vec;
-  CHECK_LE(begin + extent, num_elems);
-  std::vector<unsigned> indices;
+  std::vector<llvm::Constant*> indices;
+  indices.reserve(extent);
   for (int i = 0; i < extent; ++i) {
-    indices.push_back(begin + i);
+    if (begin + i >= 0 && begin + i < num_elems) {
+      indices.push_back(llvm::ConstantInt::get(t_int32_, begin + i));
+    } else {
+      indices.push_back(llvm::UndefValue::get(t_int32_));
+    }
   }
-  return builder_->CreateShuffleVector(vec, vec, indices);
+  return builder_->CreateShuffleVector(vec, vec, llvm::ConstantVector::get(indices));
 }
 
 llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
@@ -446,24 +451,31 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
         v->getType()->getVectorNumElements());
   }
   while (vecs.size() > 1) {
-    for (size_t i = 0; i < vecs.size(); i+=2) {
-      if (i + 1 >= vecs.size()) {
-        vecs[i / 2] = vecs[i]; continue;
-      }
+    std::vector<llvm::Value*> new_vecs;
+    for (size_t i = 0; i < vecs.size() - 1; i += 2) {
       llvm::Value* lhs = vecs[i];
       llvm::Value* rhs = vecs[i + 1];
-      int lanes = static_cast<int>(std::max(
-          lhs->getType()->getVectorNumElements(),
-          rhs->getType()->getVectorNumElements()));
-      lhs = CreateVecPad(lhs, lanes);
-      rhs = CreateVecPad(lhs, lanes);
+      const auto lhs_lanes = lhs->getType()->getVectorNumElements();
+      const auto rhs_lanes = rhs->getType()->getVectorNumElements();
+      if (lhs_lanes < rhs_lanes) {
+        lhs = CreateVecPad(lhs, rhs_lanes);
+      } else if (rhs_lanes < lhs_lanes) {
+        rhs = CreateVecPad(rhs, lhs_lanes);
+      }
+      const auto shared_lanes = std::max(lhs_lanes, rhs_lanes);
       std::vector<unsigned> mask;
-      for (int i = 0; i < lanes * 2; ++i) {
+      for (int i = 0; i < lhs_lanes; ++i) {
         mask.push_back(i);
       }
-      vecs[i / 2] = builder_->CreateShuffleVector(lhs, rhs, mask);
+      for (int i = 0; i < rhs_lanes; ++i) {
+        mask.push_back(shared_lanes + i);
+      }
+      new_vecs.push_back(builder_->CreateShuffleVector(lhs, rhs, mask));
+    }
+    if (vecs.size() % 2 != 0) {
+      new_vecs.push_back(vecs.back());
     }
-    vecs.resize((vecs.size() + 1) / 2);
+    vecs.swap(new_vecs);
   }
   return CreateVecSlice(vecs[0], 0, total_lanes);
 }
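The two hunks above are the generic groundwork: CreateVecSlice can now be asked for lanes past the end of the source vector (those lanes become undef instead of tripping the old CHECK_LE), and CreateVecConcat pads the narrower operand before shuffling, also fixing the old `rhs = CreateVecPad(lhs, lanes)` typo that padded lhs twice. A hand-simulation of the two shuffle-mask constructions, for illustration only (plain Python, not TVM code; the lane counts are made up):

    # Slice mask: lanes inside [0, num_elems) get a concrete index,
    # anything else becomes undef, so a slice may run past the source.
    def vec_slice_mask(num_elems, begin, extent):
        return [begin + i if 0 <= begin + i < num_elems else "undef"
                for i in range(extent)]

    # Slicing 8 lanes starting at 4 from a 9-lane vector:
    assert vec_slice_mask(9, 4, 8) == [4, 5, 6, 7, 8, "undef", "undef", "undef"]

    # Concat mask: the first lhs_lanes lanes of lhs, then the first
    # rhs_lanes lanes of rhs (rhs lane j sits at shared + j once both
    # operands are padded to the shared width).
    def concat_mask(lhs_lanes, rhs_lanes):
        shared = max(lhs_lanes, rhs_lanes)
        return list(range(lhs_lanes)) + [shared + j for j in range(rhs_lanes)]

    # Concatenating an 8-lane vector with a 1-lane vector yields 9 lanes:
    assert concat_mask(8, 1) == [0, 1, 2, 3, 4, 5, 6, 7, 8]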
diff --git a/src/codegen/llvm/codegen_x86_64.cc b/src/codegen/llvm/codegen_x86_64.cc
new file mode 100644
index 0000000000000..d5a9a8fa30be1
--- /dev/null
+++ b/src/codegen/llvm/codegen_x86_64.cc
@@ -0,0 +1,104 @@
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file codegen_x86_64.cc
+ * \brief X86-64 specific code generator
+ */
+#ifdef TVM_LLVM_VERSION
+#include "codegen_cpu.h"
+
+#include "llvm/MC/MCSubtargetInfo.h"
+
+namespace tvm {
+namespace codegen {
+
+namespace {
+bool TargetHasFeature(const llvm::TargetMachine& tm, const std::string& feature) {
+  const auto* MCInfo = tm.getMCSubtargetInfo();
+  return MCInfo->checkFeatures(std::string("+") + feature);
+}
+}  // namespace
+
+class CodeGenX86_64 final : public CodeGenCPU {
+ public:
+  llvm::Value* VisitExpr_(const Cast* op) override;
+
+ private:
+  llvm::Value* CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes, llvm::Type* result_ty,
+                                const std::vector<llvm::Value*>& args);
+};
+
+llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) {
+  // LLVM does not automatically generate the correct instruction sequences for
+  // half -> float conversion (i.e. using AVX2/AVX-512 vectorized variants of
+  // vcvtph2ps), so we explicitly generate them ourselves.
+  const auto from = op->value.type();
+  const auto to = op->type;
+  if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) {
+    CHECK_EQ(from.lanes(), to.lanes());
+    CHECK_NOTNULL(target_machine_);
+
+    const auto has_f16c = TargetHasFeature(*target_machine_, "f16c");
+    const auto has_avx512 = TargetHasFeature(*target_machine_, "avx512f");
+
+    if (from.lanes() >= 16 && has_avx512) {
+      return CallVectorIntrin(
+          ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, LLVMType(Float(32, from.lanes())),
+          {
+              MakeValue(ir::Call::make(Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
+                                       ir::Call::PureIntrinsic)),
+              MakeValue(ir::Broadcast::make(ir::FloatImm::make(Float(32), 0), from.lanes())),
+              /*mask=*/MakeValue(ir::IntImm::make(Int(16), -1)),
+              /*rounding-mode=*/MakeValue(ir::IntImm::make(Int(32), 4)),
+          });
+    }
+
+    if (from.lanes() >= 8 && has_f16c) {
+      return CallVectorIntrin(
+          ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(Float(32, from.lanes())),
+          {MakeValue(ir::Call::make(Int(16, from.lanes()), ir::Call::reinterpret, {op->value},
+                                    ir::Call::PureIntrinsic))});
+    }
+  }
+
+  return CodeGenCPU::VisitExpr_(op);
+}
+
+llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes,
+                                             llvm::Type* result_ty,
+                                             const std::vector<llvm::Value*>& args) {
+  llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id, {});
+  if (intrin_lanes == result_ty->getVectorNumElements()) {
+    return builder_->CreateCall(f, args);
+  }
+
+  // Otherwise, we split the vector into intrin_lanes sized elements (widening where necessary),
+  // compute each result, and then concatenate the vectors (slicing the result if necessary).
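+  // For example, a 9-lane fp16 -> fp32 cast lowered with the 8-lane F16C
+  // intrinsic slices lanes [0, 8) and [8, 16) of the input (the second slice
+  // is padded with undef lanes), converts each 8-lane piece, concatenates the
+  // two results into 16 lanes, and slices the first 9 lanes as the result.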
+  CHECK_LT(intrin_lanes, result_ty->getVectorNumElements());
+  std::vector<llvm::Value*> split_results;
+  for (auto i = 0; i < result_ty->getVectorNumElements(); i += intrin_lanes) {
+    std::vector<llvm::Value*> split_args;
+    for (const auto& v : args) {
+      if (v->getType()->isVectorTy()) {
+        CHECK_EQ(v->getType()->getVectorNumElements(), result_ty->getVectorNumElements());
+        split_args.push_back(CreateVecSlice(v, i, intrin_lanes));
+      } else {
+        split_args.push_back(v);
+      }
+    }
+    split_results.push_back(CallVectorIntrin(
+        id, intrin_lanes, llvm::VectorType::get(result_ty->getScalarType(), intrin_lanes),
+        split_args));
+  }
+  return CreateVecSlice(CreateVecConcat(split_results), 0, result_ty->getVectorNumElements());
+}
+
+TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64")
+.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
+    CodeGenLLVM* cg = new CodeGenX86_64();
+    *rv = static_cast<void*>(cg);
+  });
+
+}  // namespace codegen
+}  // namespace tvm
+#endif  // TVM_LLVM_VERSION
diff --git a/tests/python/unittest/test_codegen_x86.py b/tests/python/unittest/test_codegen_x86.py
new file mode 100644
index 0000000000000..0281a4c6b8946
--- /dev/null
+++ b/tests/python/unittest/test_codegen_x86.py
@@ -0,0 +1,55 @@
+import tvm
+import re
+
+
+def test_fp16_to_fp32():
+    def fp16_to_fp32(target, width, match=None, not_match=None):
+        elements = 64
+        n = tvm.convert(elements)
+        A = tvm.placeholder((n, width), dtype="float16", name='A')
+        B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name='B')
+        s = tvm.create_schedule(B.op)
+        s[B].vectorize(s[B].op.axis[1])
+        f = tvm.build(s, [A, B], target)
+
+        assembly = f.get_source('asm').splitlines()
+        if match:
+            matches = [l for l in assembly if re.search(match, l)]
+            assert matches
+        if not_match:
+            not_matches = [l for l in assembly if re.search(not_match, l)]
+            assert not not_matches
+
+
+    fp16_to_fp32(
+        'llvm -mcpu=skylake-avx512', 15,
+        match="vcvtph2ps.*ymm", not_match="vcvtph2ps.*zmm")
+    fp16_to_fp32(
+        'llvm -mcpu=skylake-avx512', 16,
+        match="vcvtph2ps.*zmm")
+    fp16_to_fp32(
+        'llvm -mcpu=skylake-avx512', 17,
+        match="vcvtph2ps.*zmm")
+    fp16_to_fp32(
+        'llvm -mcpu=skylake-avx512', 49,
+        match="vcvtph2ps.*zmm")
+    fp16_to_fp32(
+        'llvm -mcpu=skylake-avx512 -mattr=-avx512f', 49,
+        match="vcvtph2ps.*ymm",
+        not_match="vcvtph2ps.*zmm")
+    fp16_to_fp32(
+        'llvm -mcpu=skylake-avx512 -mattr=-f16c,-avx512f', 49,
+        not_match="vcvtph2ps")
+    fp16_to_fp32(
+        'llvm -mcpu=core-avx2', 8,
+        match="vcvtph2ps.*ymm")
+    fp16_to_fp32(
+        'llvm -mcpu=core-avx2', 9,
+        match="vcvtph2ps.*ymm")
+    fp16_to_fp32(
+        'llvm', 9,
+        not_match="vcvtph2ps")
+
+
+if __name__ == "__main__":
+    test_fp16_to_fp32()
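For a quick manual check of the new lowering, here is a minimal repro distilled from the test above (it reuses the same schedule and build calls as fp16_to_fp32, and assumes an LLVM-enabled build of TVM with an AVX-512-capable target string):

    import tvm

    # 17 fp16 lanes exercise the split path: one full 16-lane
    # x86_avx512_mask_vcvtph2ps_512 call plus an undef-padded tail.
    n = tvm.convert(64)
    A = tvm.placeholder((n, 17), dtype="float16", name="A")
    B = tvm.compute(A.shape, lambda *i: A(*i).astype("float32"), name="B")
    s = tvm.create_schedule(B.op)
    s[B].vectorize(s[B].op.axis[1])
    f = tvm.build(s, [A, B], "llvm -mcpu=skylake-avx512")
    print(f.get_source("asm"))  # expect vcvtph2ps on zmm registers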