Skip to content

Commit

Permalink
Sugar for virtual functions
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 2, 2021
1 parent 8f0a66a commit dbb6552
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 75 deletions.
54 changes: 51 additions & 3 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,7 @@ class Enzyme : public ModulePass {
Fn->getName() == "__enzyme_double" ||
Fn->getName() == "__enzyme_integer" ||
Fn->getName() == "__enzyme_pointer" ||
Fn->getName().contains("__enzyme_virtualreverse") ||
Fn->getName().contains("__enzyme_call_inactive") ||
Fn->getName().contains("__enzyme_autodiff") ||
Fn->getName().contains("__enzyme_fwddiff") ||
Expand Down Expand Up @@ -763,6 +764,7 @@ class Enzyme : public ModulePass {
}

std::map<CallInst *, DerivativeMode> toLower;
std::map<CallInst *, DerivativeMode> toVirtual;
std::set<CallInst *> InactiveCalls;
std::set<CallInst *> IterCalls;
retry:;
Expand Down Expand Up @@ -824,6 +826,10 @@ class Enzyme : public ModulePass {
}
}
}
if (Fn->getName() == "__enzyme_virtualreverse") {
Fn->addFnAttr(Attribute::ReadNone);
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
}
if (Fn->getName() == "__enzyme_iter") {
Fn->addFnAttr(Attribute::ReadNone);
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
Expand Down Expand Up @@ -960,6 +966,7 @@ class Enzyme : public ModulePass {
}

bool enableEnzyme = false;
bool virtualCall = false;
DerivativeMode mode;
if (Fn->getName().contains("__enzyme_autodiff")) {
enableEnzyme = true;
Expand All @@ -973,10 +980,13 @@ class Enzyme : public ModulePass {
} else if (Fn->getName().contains("__enzyme_reverse")) {
enableEnzyme = true;
mode = DerivativeMode::ReverseModeGradient;
} else if (Fn->getName().contains("__enzyme_virtualreverse")) {
enableEnzyme = true;
virtualCall = true;
mode = DerivativeMode::ReverseModeCombined;
}

if (enableEnzyme) {
toLower[CI] = mode;

Value *fn = CI->getArgOperand(0);
while (auto ci = dyn_cast<CastInst>(fn)) {
Expand Down Expand Up @@ -1013,9 +1023,16 @@ class Enzyme : public ModulePass {
}
goto retry;
}
if (auto dc = dyn_cast<Function>(fn))

if (virtualCall)
toVirtual[CI] = mode;
else
toLower[CI] = mode;

if (auto dc = dyn_cast<Function>(fn)) {
Changed |=
lowerEnzymeCalls(*dc, /*PostOpt*/ true, successful, done);
}
}
}
}
Expand Down Expand Up @@ -1048,6 +1065,36 @@ class Enzyme : public ModulePass {
break;
}

for (auto pair : toVirtual) {
auto CI = pair.first;
Value *fn = CI->getArgOperand(0);
while (auto ci = dyn_cast<CastInst>(fn)) {
fn = ci->getOperand(0);
}
while (auto ci = dyn_cast<BlockAddress>(fn)) {
fn = ci->getFunction();
}
while (auto ci = dyn_cast<ConstantExpr>(fn)) {
fn = ci->getOperand(0);
}
auto F = cast<Function>(fn);
TypeAnalysis TA(TLI);

auto Arch =
llvm::Triple(
CI->getParent()->getParent()->getParent()->getTargetTriple())
.getArch();

bool AtomicAdd = Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
Arch == Triple::amdgcn;

auto val = GradientUtils::GetOrCreateShadowFunction(Logic, TLI, TA, F,
AtomicAdd, PostOpt);
CI->replaceAllUsesWith(ConstantExpr::getPointerCast(val, CI->getType()));
CI->eraseFromParent();
Changed = true;
}

if (Changed) {
// TODO consider enabling when attributor does not delete
// dead internal functions, which invalidates Enzyme's cache
Expand Down Expand Up @@ -1199,7 +1246,8 @@ class Enzyme : public ModulePass {
for (Function &F : M) {
if (F.getName() == "__enzyme_float" || F.getName() == "__enzyme_double" ||
F.getName() == "__enzyme_integer" ||
F.getName() == "__enzyme_pointer") {
F.getName() == "__enzyme_pointer" ||
F.getName().contains("__enzyme_virtualreverse")) {
F.addFnAttr(Attribute::ReadNone);
for (auto &arg : F.args()) {
if (arg.getType()->isPointerTy()) {
Expand Down
151 changes: 79 additions & 72 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2487,6 +2487,84 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
return res;
}

Constant *GradientUtils::GetOrCreateShadowFunction(EnzymeLogic &Logic,
TargetLibraryInfo &TLI,
TypeAnalysis &TA,
Function *fn, bool AtomicAdd,
bool PostOpt) {
//! Todo allow tape propagation
// Note that specifically this should _not_ be called with topLevel=true
// (since it may not be valid to always assume we can recompute the
// augmented primal) However, in the absence of a way to pass tape data
// from an indirect augmented (and also since we dont presently allow
// indirect augmented calls), topLevel MUST be true otherwise subcalls will
// not be able to lookup the augmenteddata/subdata (triggering an assertion
// failure, among much worse)
std::map<Argument *, bool> uncacheable_args;
FnTypeInfo type_args(fn);

// conservatively assume that we can only cache existing floating types
// (i.e. that all args are uncacheable)
std::vector<DIFFE_TYPE> types;
for (auto &a : fn->args()) {
uncacheable_args[&a] = !a.getType()->isFPOrFPVectorTy();
type_args.Arguments.insert(std::pair<Argument *, TypeTree>(&a, {}));
type_args.KnownValues.insert(
std::pair<Argument *, std::set<int64_t>>(&a, {}));
DIFFE_TYPE typ;
if (a.getType()->isFPOrFPVectorTy()) {
typ = DIFFE_TYPE::OUT_DIFF;
} else if (a.getType()->isIntegerTy() &&
cast<IntegerType>(a.getType())->getBitWidth() < 16) {
typ = DIFFE_TYPE::CONSTANT;
} else if (a.getType()->isVoidTy() || a.getType()->isEmptyTy()) {
typ = DIFFE_TYPE::CONSTANT;
} else {
typ = DIFFE_TYPE::DUP_ARG;
}
types.push_back(typ);
}

DIFFE_TYPE retType = fn->getReturnType()->isFPOrFPVectorTy()
? DIFFE_TYPE::OUT_DIFF
: DIFFE_TYPE::DUP_ARG;
if (fn->getReturnType()->isVoidTy() || fn->getReturnType()->isEmptyTy() ||
(fn->getReturnType()->isIntegerTy() &&
cast<IntegerType>(fn->getReturnType())->getBitWidth() < 16))
retType = DIFFE_TYPE::CONSTANT;

// TODO re atomic add consider forcing it to be atomic always as fallback if
// used in a parallel context
auto &augdata = Logic.CreateAugmentedPrimal(
fn, retType, /*constant_args*/ types, TLI, TA,
/*returnUsed*/ !fn->getReturnType()->isEmptyTy() &&
!fn->getReturnType()->isVoidTy(),
type_args, uncacheable_args, /*forceAnonymousTape*/ true, AtomicAdd,
PostOpt);
Constant *newf = Logic.CreatePrimalAndGradient(
fn, retType, /*constant_args*/ types, TLI, TA,
/*returnValue*/ false, /*dretPtr*/ false,
DerivativeMode::ReverseModeGradient,
/*additionalArg*/ Type::getInt8PtrTy(fn->getContext()), type_args,
uncacheable_args,
/*map*/ &augdata, AtomicAdd);
if (!newf)
newf = UndefValue::get(fn->getType());
auto cdata = ConstantStruct::get(
StructType::get(newf->getContext(),
{augdata.fn->getType(), newf->getType()}),
{augdata.fn, newf});
std::string globalname = ("_enzyme_" + fn->getName() + "'").str();
auto GV = fn->getParent()->getNamedValue(globalname);

if (GV == nullptr) {
GV = new GlobalVariable(*fn->getParent(), cdata->getType(), true,
GlobalValue::LinkageTypes::InternalLinkage, cdata,
globalname);
}
return ConstantExpr::getPointerCast(GV, fn->getType());
}

Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
bool nullShadow) {
assert(oval);
Expand Down Expand Up @@ -2768,78 +2846,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
std::make_pair((const Value *)oval, InvertedPointerVH(this, cs)));
return cs;
} else if (auto fn = dyn_cast<Function>(oval)) {
//! Todo allow tape propagation
// Note that specifically this should _not_ be called with topLevel=true
// (since it may not be valid to always assume we can recompute the
// augmented primal) However, in the absence of a way to pass tape data
// from an indirect augmented (and also since we dont presently allow
// indirect augmented calls), topLevel MUST be true otherwise subcalls will
// not be able to lookup the augmenteddata/subdata (triggering an assertion
// failure, among much worse)
std::map<Argument *, bool> uncacheable_args;
FnTypeInfo type_args(fn);

// conservatively assume that we can only cache existing floating types
// (i.e. that all args are uncacheable)
std::vector<DIFFE_TYPE> types;
for (auto &a : fn->args()) {
uncacheable_args[&a] = !a.getType()->isFPOrFPVectorTy();
type_args.Arguments.insert(std::pair<Argument *, TypeTree>(&a, {}));
type_args.KnownValues.insert(
std::pair<Argument *, std::set<int64_t>>(&a, {}));
DIFFE_TYPE typ;
if (a.getType()->isFPOrFPVectorTy()) {
typ = DIFFE_TYPE::OUT_DIFF;
} else if (a.getType()->isIntegerTy() &&
cast<IntegerType>(a.getType())->getBitWidth() < 16) {
typ = DIFFE_TYPE::CONSTANT;
} else if (a.getType()->isVoidTy() || a.getType()->isEmptyTy()) {
typ = DIFFE_TYPE::CONSTANT;
} else {
typ = DIFFE_TYPE::DUP_ARG;
}
types.push_back(typ);
}

DIFFE_TYPE retType = fn->getReturnType()->isFPOrFPVectorTy()
? DIFFE_TYPE::OUT_DIFF
: DIFFE_TYPE::DUP_ARG;
if (fn->getReturnType()->isVoidTy() || fn->getReturnType()->isEmptyTy() ||
(fn->getReturnType()->isIntegerTy() &&
cast<IntegerType>(fn->getReturnType())->getBitWidth() < 16))
retType = DIFFE_TYPE::CONSTANT;

// TODO re atomic add consider forcing it to be atomic always as fallback if
// used in a parallel context
auto &augdata = Logic.CreateAugmentedPrimal(
fn, retType, /*constant_args*/ types, TLI, TA,
/*returnUsed*/ !fn->getReturnType()->isEmptyTy() &&
!fn->getReturnType()->isVoidTy(),
type_args, uncacheable_args, /*forceAnonymousTape*/ true, AtomicAdd,
/*PostOpt*/ false);
Constant *newf = Logic.CreatePrimalAndGradient(
fn, retType, /*constant_args*/ types, TLI, TA,
/*returnValue*/ false, /*dretPtr*/ false,
DerivativeMode::ReverseModeGradient,
/*additionalArg*/ Type::getInt8PtrTy(fn->getContext()), type_args,
uncacheable_args,
/*map*/ &augdata, AtomicAdd);
if (!newf)
newf = UndefValue::get(fn->getType());
auto cdata = ConstantStruct::get(
StructType::get(newf->getContext(),
{augdata.fn->getType(), newf->getType()}),
{augdata.fn, newf});
std::string globalname = ("_enzyme_" + fn->getName() + "'").str();
auto GV = fn->getParent()->getNamedValue(globalname);

if (GV == nullptr) {
GV = new GlobalVariable(*fn->getParent(), cdata->getType(), true,
GlobalValue::LinkageTypes::InternalLinkage, cdata,
globalname);
}

return BuilderM.CreatePointerCast(GV, fn->getType());
return GetOrCreateShadowFunction(Logic, TLI, TA, fn, AtomicAdd);
} else if (auto arg = dyn_cast<CastInst>(oval)) {
IRBuilder<> bb(getNewFromOriginal(arg));
Value *invertOp = invertPointerM(arg->getOperand(0), bb);
Expand Down
6 changes: 6 additions & 0 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,12 @@ class GradientUtils : public CacheUtility {
Value *invertPointerM(Value *val, IRBuilder<> &BuilderM,
bool nullShadow = false);

static Constant *GetOrCreateShadowFunction(EnzymeLogic &Logic,
TargetLibraryInfo &TLI,
TypeAnalysis &TA, Function *F,
bool AtomicAdd = true,
bool PostOpt = false);

void branchToCorrespondingTarget(
BasicBlock *ctx, IRBuilder<> &BuilderM,
const std::map<BasicBlock *,
Expand Down
49 changes: 49 additions & 0 deletions enzyme/test/Integration/ReverseMode/virtualshadow2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -

#include <stdio.h>
#include "test_utils.h"

struct S {
double (*fn)(double);
double val;
};

double square(double x){ return x * x; }

double foo(struct S* s) {
return square(s->val);
}


void primal() {
struct S s;
s.fn = square;
s.val = 3.0;
printf("%f\n", foo(&s));
}

void* __enzyme_virtualreverse(void*);
void __enzyme_autodiff(void*, void*, void*);
void reverse() {
struct S s;
s.fn = square;
s.val = 3.0;
struct S d_s;
d_s.fn = (double (*)(double))__enzyme_virtualreverse((void*)square);
d_s.val = 0.0;
__enzyme_autodiff((void*)foo, &s, &d_s);
printf("shadow res=%f\n", d_s.val);
APPROX_EQ(d_s.val, 6.0, 1e-7);
}

int main() {
primal();
reverse();
}

0 comments on commit dbb6552

Please sign in to comment.