diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index e7ae6c1d7fcb6..63730de7383d2 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -725,6 +725,7 @@ class Enzyme : public ModulePass { Fn->getName() == "__enzyme_double" || Fn->getName() == "__enzyme_integer" || Fn->getName() == "__enzyme_pointer" || + Fn->getName().contains("__enzyme_virtualreverse") || Fn->getName().contains("__enzyme_call_inactive") || Fn->getName().contains("__enzyme_autodiff") || Fn->getName().contains("__enzyme_fwddiff") || @@ -763,6 +764,7 @@ class Enzyme : public ModulePass { } std::map toLower; + std::map toVirtual; std::set InactiveCalls; std::set IterCalls; retry:; @@ -824,6 +826,10 @@ class Enzyme : public ModulePass { } } } + if (Fn->getName() == "__enzyme_virtualreverse") { + Fn->addFnAttr(Attribute::ReadNone); + CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); + } if (Fn->getName() == "__enzyme_iter") { Fn->addFnAttr(Attribute::ReadNone); CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); @@ -960,6 +966,7 @@ class Enzyme : public ModulePass { } bool enableEnzyme = false; + bool virtualCall = false; DerivativeMode mode; if (Fn->getName().contains("__enzyme_autodiff")) { enableEnzyme = true; @@ -973,10 +980,13 @@ class Enzyme : public ModulePass { } else if (Fn->getName().contains("__enzyme_reverse")) { enableEnzyme = true; mode = DerivativeMode::ReverseModeGradient; + } else if (Fn->getName().contains("__enzyme_virtualreverse")) { + enableEnzyme = true; + virtualCall = true; + mode = DerivativeMode::ReverseModeCombined; } if (enableEnzyme) { - toLower[CI] = mode; Value *fn = CI->getArgOperand(0); while (auto ci = dyn_cast(fn)) { @@ -1013,9 +1023,16 @@ class Enzyme : public ModulePass { } goto retry; } - if (auto dc = dyn_cast(fn)) + + if (virtualCall) + toVirtual[CI] = mode; + else + toLower[CI] = mode; + + if (auto dc = dyn_cast(fn)) { Changed |= lowerEnzymeCalls(*dc, /*PostOpt*/ true, successful, done); + } } } } @@ -1048,6 +1065,36 @@ class Enzyme : public ModulePass { break; } + for (auto pair : toVirtual) { + auto CI = pair.first; + Value *fn = CI->getArgOperand(0); + while (auto ci = dyn_cast(fn)) { + fn = ci->getOperand(0); + } + while (auto ci = dyn_cast(fn)) { + fn = ci->getFunction(); + } + while (auto ci = dyn_cast(fn)) { + fn = ci->getOperand(0); + } + auto F = cast(fn); + TypeAnalysis TA(TLI); + + auto Arch = + llvm::Triple( + CI->getParent()->getParent()->getParent()->getTargetTriple()) + .getArch(); + + bool AtomicAdd = Arch == Triple::nvptx || Arch == Triple::nvptx64 || + Arch == Triple::amdgcn; + + auto val = GradientUtils::GetOrCreateShadowFunction(Logic, TLI, TA, F, + AtomicAdd, PostOpt); + CI->replaceAllUsesWith(ConstantExpr::getPointerCast(val, CI->getType())); + CI->eraseFromParent(); + Changed = true; + } + if (Changed) { // TODO consider enabling when attributor does not delete // dead internal functions, which invalidates Enzyme's cache @@ -1199,7 +1246,8 @@ class Enzyme : public ModulePass { for (Function &F : M) { if (F.getName() == "__enzyme_float" || F.getName() == "__enzyme_double" || F.getName() == "__enzyme_integer" || - F.getName() == "__enzyme_pointer") { + F.getName() == "__enzyme_pointer" || + F.getName().contains("__enzyme_virtualreverse")) { F.addFnAttr(Attribute::ReadNone); for (auto &arg : F.args()) { if (arg.getType()->isPointerTy()) { diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 536ee2d326b0d..00ad6971bacec 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -2487,6 +2487,84 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( return res; } +Constant *GradientUtils::GetOrCreateShadowFunction(EnzymeLogic &Logic, + TargetLibraryInfo &TLI, + TypeAnalysis &TA, + Function *fn, bool AtomicAdd, + bool PostOpt) { + //! Todo allow tape propagation + // Note that specifically this should _not_ be called with topLevel=true + // (since it may not be valid to always assume we can recompute the + // augmented primal) However, in the absence of a way to pass tape data + // from an indirect augmented (and also since we dont presently allow + // indirect augmented calls), topLevel MUST be true otherwise subcalls will + // not be able to lookup the augmenteddata/subdata (triggering an assertion + // failure, among much worse) + std::map uncacheable_args; + FnTypeInfo type_args(fn); + + // conservatively assume that we can only cache existing floating types + // (i.e. that all args are uncacheable) + std::vector types; + for (auto &a : fn->args()) { + uncacheable_args[&a] = !a.getType()->isFPOrFPVectorTy(); + type_args.Arguments.insert(std::pair(&a, {})); + type_args.KnownValues.insert( + std::pair>(&a, {})); + DIFFE_TYPE typ; + if (a.getType()->isFPOrFPVectorTy()) { + typ = DIFFE_TYPE::OUT_DIFF; + } else if (a.getType()->isIntegerTy() && + cast(a.getType())->getBitWidth() < 16) { + typ = DIFFE_TYPE::CONSTANT; + } else if (a.getType()->isVoidTy() || a.getType()->isEmptyTy()) { + typ = DIFFE_TYPE::CONSTANT; + } else { + typ = DIFFE_TYPE::DUP_ARG; + } + types.push_back(typ); + } + + DIFFE_TYPE retType = fn->getReturnType()->isFPOrFPVectorTy() + ? DIFFE_TYPE::OUT_DIFF + : DIFFE_TYPE::DUP_ARG; + if (fn->getReturnType()->isVoidTy() || fn->getReturnType()->isEmptyTy() || + (fn->getReturnType()->isIntegerTy() && + cast(fn->getReturnType())->getBitWidth() < 16)) + retType = DIFFE_TYPE::CONSTANT; + + // TODO re atomic add consider forcing it to be atomic always as fallback if + // used in a parallel context + auto &augdata = Logic.CreateAugmentedPrimal( + fn, retType, /*constant_args*/ types, TLI, TA, + /*returnUsed*/ !fn->getReturnType()->isEmptyTy() && + !fn->getReturnType()->isVoidTy(), + type_args, uncacheable_args, /*forceAnonymousTape*/ true, AtomicAdd, + PostOpt); + Constant *newf = Logic.CreatePrimalAndGradient( + fn, retType, /*constant_args*/ types, TLI, TA, + /*returnValue*/ false, /*dretPtr*/ false, + DerivativeMode::ReverseModeGradient, + /*additionalArg*/ Type::getInt8PtrTy(fn->getContext()), type_args, + uncacheable_args, + /*map*/ &augdata, AtomicAdd); + if (!newf) + newf = UndefValue::get(fn->getType()); + auto cdata = ConstantStruct::get( + StructType::get(newf->getContext(), + {augdata.fn->getType(), newf->getType()}), + {augdata.fn, newf}); + std::string globalname = ("_enzyme_" + fn->getName() + "'").str(); + auto GV = fn->getParent()->getNamedValue(globalname); + + if (GV == nullptr) { + GV = new GlobalVariable(*fn->getParent(), cdata->getType(), true, + GlobalValue::LinkageTypes::InternalLinkage, cdata, + globalname); + } + return ConstantExpr::getPointerCast(GV, fn->getType()); +} + Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, bool nullShadow) { assert(oval); @@ -2768,78 +2846,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, std::make_pair((const Value *)oval, InvertedPointerVH(this, cs))); return cs; } else if (auto fn = dyn_cast(oval)) { - //! Todo allow tape propagation - // Note that specifically this should _not_ be called with topLevel=true - // (since it may not be valid to always assume we can recompute the - // augmented primal) However, in the absence of a way to pass tape data - // from an indirect augmented (and also since we dont presently allow - // indirect augmented calls), topLevel MUST be true otherwise subcalls will - // not be able to lookup the augmenteddata/subdata (triggering an assertion - // failure, among much worse) - std::map uncacheable_args; - FnTypeInfo type_args(fn); - - // conservatively assume that we can only cache existing floating types - // (i.e. that all args are uncacheable) - std::vector types; - for (auto &a : fn->args()) { - uncacheable_args[&a] = !a.getType()->isFPOrFPVectorTy(); - type_args.Arguments.insert(std::pair(&a, {})); - type_args.KnownValues.insert( - std::pair>(&a, {})); - DIFFE_TYPE typ; - if (a.getType()->isFPOrFPVectorTy()) { - typ = DIFFE_TYPE::OUT_DIFF; - } else if (a.getType()->isIntegerTy() && - cast(a.getType())->getBitWidth() < 16) { - typ = DIFFE_TYPE::CONSTANT; - } else if (a.getType()->isVoidTy() || a.getType()->isEmptyTy()) { - typ = DIFFE_TYPE::CONSTANT; - } else { - typ = DIFFE_TYPE::DUP_ARG; - } - types.push_back(typ); - } - - DIFFE_TYPE retType = fn->getReturnType()->isFPOrFPVectorTy() - ? DIFFE_TYPE::OUT_DIFF - : DIFFE_TYPE::DUP_ARG; - if (fn->getReturnType()->isVoidTy() || fn->getReturnType()->isEmptyTy() || - (fn->getReturnType()->isIntegerTy() && - cast(fn->getReturnType())->getBitWidth() < 16)) - retType = DIFFE_TYPE::CONSTANT; - - // TODO re atomic add consider forcing it to be atomic always as fallback if - // used in a parallel context - auto &augdata = Logic.CreateAugmentedPrimal( - fn, retType, /*constant_args*/ types, TLI, TA, - /*returnUsed*/ !fn->getReturnType()->isEmptyTy() && - !fn->getReturnType()->isVoidTy(), - type_args, uncacheable_args, /*forceAnonymousTape*/ true, AtomicAdd, - /*PostOpt*/ false); - Constant *newf = Logic.CreatePrimalAndGradient( - fn, retType, /*constant_args*/ types, TLI, TA, - /*returnValue*/ false, /*dretPtr*/ false, - DerivativeMode::ReverseModeGradient, - /*additionalArg*/ Type::getInt8PtrTy(fn->getContext()), type_args, - uncacheable_args, - /*map*/ &augdata, AtomicAdd); - if (!newf) - newf = UndefValue::get(fn->getType()); - auto cdata = ConstantStruct::get( - StructType::get(newf->getContext(), - {augdata.fn->getType(), newf->getType()}), - {augdata.fn, newf}); - std::string globalname = ("_enzyme_" + fn->getName() + "'").str(); - auto GV = fn->getParent()->getNamedValue(globalname); - - if (GV == nullptr) { - GV = new GlobalVariable(*fn->getParent(), cdata->getType(), true, - GlobalValue::LinkageTypes::InternalLinkage, cdata, - globalname); - } - - return BuilderM.CreatePointerCast(GV, fn->getType()); + return GetOrCreateShadowFunction(Logic, TLI, TA, fn, AtomicAdd); } else if (auto arg = dyn_cast(oval)) { IRBuilder<> bb(getNewFromOriginal(arg)); Value *invertOp = invertPointerM(arg->getOperand(0), bb); diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index e1a6d1ec07ba0..37a9980b3bcb1 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -1323,6 +1323,12 @@ class GradientUtils : public CacheUtility { Value *invertPointerM(Value *val, IRBuilder<> &BuilderM, bool nullShadow = false); + static Constant *GetOrCreateShadowFunction(EnzymeLogic &Logic, + TargetLibraryInfo &TLI, + TypeAnalysis &TA, Function *F, + bool AtomicAdd = true, + bool PostOpt = false); + void branchToCorrespondingTarget( BasicBlock *ctx, IRBuilder<> &BuilderM, const std::map +#include "test_utils.h" + +struct S { + double (*fn)(double); + double val; +}; + +double square(double x){ return x * x; } + +double foo(struct S* s) { + return square(s->val); +} + + +void primal() { + struct S s; + s.fn = square; + s.val = 3.0; + printf("%f\n", foo(&s)); +} + +void* __enzyme_virtualreverse(void*); +void __enzyme_autodiff(void*, void*, void*); +void reverse() { + struct S s; + s.fn = square; + s.val = 3.0; + struct S d_s; + d_s.fn = (double (*)(double))__enzyme_virtualreverse((void*)square); + d_s.val = 0.0; + __enzyme_autodiff((void*)foo, &s, &d_s); + printf("shadow res=%f\n", d_s.val); + APPROX_EQ(d_s.val, 6.0, 1e-7); +} + +int main() { + primal(); + reverse(); +}