Skip to content

Commit

Permalink
Support both Float16 ABIs depending on LLVM and platform (#49527)
Browse files Browse the repository at this point in the history
There are two Float16 ABIs in the wild: one for platforms that have a
defined register, and the original one where we used i16.

LLVM 15 follows GCC and uses the new ABI on x86/ARM but not PPC.

Co-authored-by: Gabriel Baraldi <baraldigabriel@gmail.com>
  • Loading branch information
vchuravy and gbaraldi authored Apr 27, 2023
1 parent 09a0f34 commit 959902f
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ static void reportWriterError(const ErrorInfoBase &E)
jl_safe_printf("ERROR: failed to emit output file %s\n", err.c_str());
}

#if JULIA_FLOAT16_ABI == 1
static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT)
{
Function *target = M.getFunction(alias);
Expand All @@ -510,7 +511,7 @@ static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionT
auto val = builder.CreateCall(target, CallArgs);
builder.CreateRet(val);
}

#endif
void multiversioning_preannotate(Module &M);

// See src/processor.h for documentation about this table. Corresponds to jl_image_shard_t.
Expand Down Expand Up @@ -943,6 +944,8 @@ struct ShardTimers {
}
};

// Forward declaration; the definition lives in src/codegen.cpp and is only
// compiled when JULIA_FLOAT16_ABI == 2, so this must only be called from code
// guarded by that same condition.
void emitFloat16Wrappers(Module &M, bool external);

// Perform the actual optimization and emission of the output files
static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *outputs, const std::string *names,
NewArchiveMember *unopt, NewArchiveMember *opt, NewArchiveMember *obj, NewArchiveMember *asm_,
Expand Down Expand Up @@ -1003,7 +1006,9 @@ static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *out
}
}
// no need to inject aliases if we have no functions

if (inject_aliases) {
#if JULIA_FLOAT16_ABI == 1
// We would like to emit an alias or a weakref alias to redirect these symbols
// but LLVM doesn't let us emit a GlobalAlias to a declaration...
// So for now we inject a definition of these functions that calls our runtime
Expand All @@ -1018,8 +1023,10 @@ static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *out
FunctionType::get(Type::getHalfTy(M.getContext()), { Type::getFloatTy(M.getContext()) }, false));
injectCRTAlias(M, "__truncdfhf2", "julia__truncdfhf2",
FunctionType::get(Type::getHalfTy(M.getContext()), { Type::getDoubleTy(M.getContext()) }, false));
#else
emitFloat16Wrappers(M, false);
#endif
}

timers.optimize.stopTimer();

if (opt) {
Expand Down
56 changes: 56 additions & 0 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5818,6 +5818,7 @@ static void emit_cfunc_invalidate(
prepare_call_in(gf_thunk->getParent(), jlapplygeneric_func));
}

#include <iostream>
static Function* gen_cfun_wrapper(
Module *into, jl_codegen_params_t &params,
const function_sig_t &sig, jl_value_t *ff, const char *aliasname,
Expand Down Expand Up @@ -8704,6 +8705,58 @@ static JuliaVariable *julia_const_gv(jl_value_t *val)
return nullptr;
}

// Handle FLOAT16 ABI v2
#if JULIA_FLOAT16_ABI == 2
// Emit into `M` a function named `wrapperName` with signature `FTwrapper` whose
// body forwards to `calledName`, bitcasting every argument to the callee's
// parameter type and bitcasting the result back to the wrapper's return type.
//
//  - `calledName` is looked up in `M`; when it is missing, a declaration with
//    signature `FTcalled` and external linkage is created.
//  - The wrapper gets external linkage when `external` is true and internal
//    linkage otherwise, and is marked always-inline.
//  - If the wrapper and the callee disagree on the number of arguments the
//    process is aborted after printing a diagnostic to llvm::errs().
static void makeCastCall(Module &M, StringRef wrapperName, StringRef calledName, FunctionType *FTwrapper, FunctionType *FTcalled, bool external)
{
    Function *calledFun = M.getFunction(calledName);
    if (!calledFun) {
        calledFun = Function::Create(FTcalled, Function::ExternalLinkage, calledName, M);
    }
    // Check the arity up front, before anything is added to the module for the
    // wrapper, and say which pair of functions could not be matched.
    if (FTwrapper->getNumParams() != calledFun->arg_size()) {
        llvm::errs() << "FATAL ERROR: Can't match wrapper to called function ("
                     << wrapperName << " -> " << calledName << ")\n";
        abort();
    }
    auto linkage = external ? Function::ExternalLinkage : Function::InternalLinkage;
    auto wrapperFun = Function::Create(FTwrapper, linkage, wrapperName, M);
    wrapperFun->addFnAttr(Attribute::AlwaysInline);
    llvm::IRBuilder<> builder(BasicBlock::Create(M.getContext(), "top", wrapperFun));
    SmallVector<Value *, 4> CallArgs;
    // Walk both argument lists in lockstep; each wrapper argument is
    // reinterpreted as the corresponding callee parameter type.
    for (auto wrapperArg = wrapperFun->arg_begin(), calledArg = calledFun->arg_begin();
         wrapperArg != wrapperFun->arg_end() && calledArg != calledFun->arg_end(); ++wrapperArg, ++calledArg)
    {
        CallArgs.push_back(builder.CreateBitCast(wrapperArg, calledArg->getType()));
    }
    auto val = builder.CreateCall(calledFun, CallArgs);
    // Reinterpret the callee's result as the wrapper's declared return type.
    auto retval = builder.CreateBitCast(val, wrapperFun->getReturnType());
    builder.CreateRet(retval);
}

// Add the Float16 conversion entry points to `M`. Each entry point takes or
// returns `half` and forwards, through makeCastCall, to a `julia_`-prefixed
// function whose signature uses `i16` in place of `half`.
// `external` is passed straight through to makeCastCall and selects the
// linkage of the emitted wrappers.
void emitFloat16Wrappers(Module &M, bool external)
{
    auto &C = M.getContext();
    Type *halfTy = Type::getHalfTy(C);
    Type *i16Ty = Type::getInt16Ty(C);
    Type *floatTy = Type::getFloatTy(C);
    Type *doubleTy = Type::getDoubleTy(C);

    // Wrapper-side signatures (expressed with half).
    FunctionType *halfToFloat = FunctionType::get(floatTy, { halfTy }, false);
    FunctionType *floatToHalf = FunctionType::get(halfTy, { floatTy }, false);
    FunctionType *doubleToHalf = FunctionType::get(halfTy, { doubleTy }, false);
    // Callee-side signatures (same shape, with i16 standing in for half).
    FunctionType *i16ToFloat = FunctionType::get(floatTy, { i16Ty }, false);
    FunctionType *floatToI16 = FunctionType::get(i16Ty, { floatTy }, false);
    FunctionType *doubleToI16 = FunctionType::get(i16Ty, { doubleTy }, false);

    const struct {
        const char *wrapper;
        const char *callee;
        FunctionType *wrapperTy;
        FunctionType *calleeTy;
    } wrappers[] = {
        { "__gnu_h2f_ieee", "julia__gnu_h2f_ieee", halfToFloat,  i16ToFloat  },
        { "__extendhfsf2",  "julia__gnu_h2f_ieee", halfToFloat,  i16ToFloat  },
        { "__gnu_f2h_ieee", "julia__gnu_f2h_ieee", floatToHalf,  floatToI16  },
        { "__truncsfhf2",   "julia__gnu_f2h_ieee", floatToHalf,  floatToI16  },
        { "__truncdfhf2",   "julia__truncdfhf2",   doubleToHalf, doubleToI16 },
    };
    // Emission order matches the table order above.
    for (const auto &w : wrappers) {
        makeCastCall(M, w.wrapper, w.callee, w.wrapperTy, w.calleeTy, external);
    }
}

// Create a module named "F16Wrappers", fill it with the externally visible
// Float16 wrappers, and hand it over to the execution engine.
static void init_f16_funcs(void)
{
    auto context = jl_ExecutionEngine->acquireContext();
    auto wrapperModule = jl_create_ts_module("F16Wrappers", context, imaging_default());
    // `true`: the wrappers are emitted with external linkage.
    emitFloat16Wrappers(*wrapperModule.getModuleUnlocked(), true);
    jl_ExecutionEngine->addModule(std::move(wrapperModule));
}
#endif

static void init_jit_functions(void)
{
add_named_global(jlstack_chk_guard_var, &__stack_chk_guard);
Expand Down Expand Up @@ -8942,6 +8995,9 @@ extern "C" JL_DLLEXPORT void jl_init_codegen_impl(void)
jl_init_llvm();
// Now that the execution engine exists, initialize all modules
init_jit_functions();
#if JULIA_FLOAT16_ABI == 2
init_f16_funcs();
#endif
}

extern "C" JL_DLLEXPORT void jl_teardown_codegen_impl() JL_NOTSAFEPOINT
Expand Down
2 changes: 2 additions & 0 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,7 @@ JuliaOJIT::JuliaOJIT()

JD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);

#if JULIA_FLOAT16_ABI == 1
orc::SymbolAliasMap jl_crt = {
{ mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
Expand All @@ -1391,6 +1392,7 @@ JuliaOJIT::JuliaOJIT()
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }
};
cantFail(GlobalJD.define(orc::symbolAliases(jl_crt)));
#endif

#ifdef MSAN_EMUTLS_WORKAROUND
orc::SymbolMap msan_crt;
Expand Down
10 changes: 10 additions & 0 deletions src/llvm-version.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <llvm/Config/llvm-config.h>
#include "julia_assert.h"
#include "platform.h"

// The LLVM version used, JL_LLVM_VERSION, is represented as a 5-digit integer
// of the form ABBCC, where A is the major version, B is minor, and C is patch.
Expand All @@ -17,6 +18,15 @@
#define JL_LLVM_OPAQUE_POINTERS 1
#endif

// JULIA_FLOAT16_ABI selects which Float16 conversion ABI the rest of the
// code base targets:
//   1 - the original ABI, in which the Float16 value is passed as an i16.
//       Chosen for LLVM versions before 15 and always on PowerPC.
//   2 - the newer ABI described below. Chosen everywhere else.
//
// Pre GCC 12 libgcc defined the ABI for Float16->Float32
// to take an i16. GCC 12 silently changed the ABI to now pass
// Float16 in Float32 registers.
#if JL_LLVM_VERSION < 150000 || defined(_CPU_PPC64_) || defined(_CPU_PPC_)
#define JULIA_FLOAT16_ABI 1
#else
#define JULIA_FLOAT16_ABI 2
#endif

#ifdef __cplusplus
#if defined(__GNUC__) && (__GNUC__ >= 9)
// Added in GCC 9, this warning is annoying
Expand Down

0 comments on commit 959902f

Please sign in to comment.