From fe2f4452c58229c363b4981f95d583a8ec0ca184 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Fri, 3 Sep 2021 22:01:59 +0200 Subject: [PATCH] Forward Mode CallInst known functions (#299) * forward mode call inst known functions * add tests * fix tests * fix julia.write_barrier * fix _ZSt29_Rb_tree_insert_and_rebalancebPSt18_Rb_tree_ * fix __dynamic_cast * Alloc / dealloc * fix else condition * fix pointer for free and add runtime check to prevent double free. * posix_memalign --- enzyme/Enzyme/AdjointGenerator.h | 1075 +++++++++++++++++------- enzyme/test/Enzyme/ForwardMode/cosh.ll | 28 + enzyme/test/Enzyme/ForwardMode/erf.ll | 29 + enzyme/test/Enzyme/ForwardMode/erfc.ll | 29 + enzyme/test/Enzyme/ForwardMode/erfi.ll | 28 + 5 files changed, 871 insertions(+), 318 deletions(-) create mode 100644 enzyme/test/Enzyme/ForwardMode/cosh.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/erf.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/erfc.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/erfi.ll diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 9a3e5ea953c6d..937f2f7773491 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -5673,27 +5673,53 @@ class AdjointGenerator } } eraseIfUnused(*orig); - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(orig)) + if (gutils->isConstantInstruction(orig)) return; - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), - Builder2); - Value *oneMx2 = Builder2.CreateFSub(ConstantFP::get(x->getType(), 1.0), - Builder2.CreateFMul(x, x)); + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); + Value *x = gutils->getNewFromOriginal(orig->getArgOperand(0)); + Value *oneMx2 = Builder2.CreateFSub( + ConstantFP::get(x->getType(), 1.0), Builder2.CreateFMul(x, x)); + + SmallVector args = {oneMx2}; + Type *tys[] = {x->getType()}; + auto cal = cast(Builder2.CreateCall( + Intrinsic::getDeclaration(called->getParent(), Intrinsic::sqrt, + tys), + args)); - SmallVector args = {oneMx2}; - Type *tys[] = {x->getType()}; - auto cal = cast( - Builder2.CreateCall(Intrinsic::getDeclaration(called->getParent(), - Intrinsic::sqrt, tys), - args)); + Value *dif0 = + Builder2.CreateFDiv(diffe(orig->getArgOperand(0), Builder2), cal); + setDiffe(orig, dif0, Builder2); + return; + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), + Builder2); + Value *oneMx2 = Builder2.CreateFSub( + ConstantFP::get(x->getType(), 1.0), Builder2.CreateFMul(x, x)); + + SmallVector args = {oneMx2}; + Type *tys[] = {x->getType()}; + auto cal = cast(Builder2.CreateCall( + Intrinsic::getDeclaration(called->getParent(), Intrinsic::sqrt, + tys), + args)); - Value *dif0 = Builder2.CreateFDiv(diffe(orig, Builder2), cal); - addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType()); - return; + Value *dif0 = Builder2.CreateFDiv(diffe(orig, Builder2), cal); + addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType()); + return; + } + case DerivativeMode::ReverseModePrimal: { + return; + } + } } if (funcName == "atan" || funcName == "atanf" || funcName == "atanl" || @@ -5706,19 +5732,37 @@ class AdjointGenerator } } eraseIfUnused(*orig); - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(orig)) + if (gutils->isConstantInstruction(orig)) return; - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), - Builder2); - Value *onePx2 = Builder2.CreateFAdd(ConstantFP::get(x->getType(), 1.0), - Builder2.CreateFMul(x, x)); - Value *dif0 = Builder2.CreateFDiv(diffe(orig, Builder2), onePx2); - addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType()); - return; + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); + Value *x = gutils->getNewFromOriginal(orig->getArgOperand(0)); + Value *onePx2 = Builder2.CreateFAdd( + ConstantFP::get(x->getType(), 1.0), Builder2.CreateFMul(x, x)); + Value *dif0 = Builder2.CreateFDiv( + diffe(orig->getArgOperand(0), Builder2), onePx2); + setDiffe(orig, dif0, Builder2); + return; + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), + Builder2); + Value *onePx2 = Builder2.CreateFAdd( + ConstantFP::get(x->getType(), 1.0), Builder2.CreateFMul(x, x)); + Value *dif0 = Builder2.CreateFDiv(diffe(orig, Builder2), onePx2); + addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType()); + return; + } + case DerivativeMode::ReverseModePrimal: { + return; + } + } } if (funcName == "cbrt") { @@ -5730,29 +5774,59 @@ class AdjointGenerator } } eraseIfUnused(*orig); - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(orig)) + if (gutils->isConstantInstruction(orig)) return; - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), - Builder2); - Value *args[] = {x}; + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); + + Value *x = gutils->getNewFromOriginal(orig->getArgOperand(0)); + Value *args[] = {x}; #if LLVM_VERSION_MAJOR >= 11 - auto callval = orig->getCalledOperand(); + auto callval = orig->getCalledOperand(); #else - auto callval = orig->getCalledValue(); + auto callval = orig->getCalledValue(); #endif - CallInst *cubcall = cast( - Builder2.CreateCall(orig->getFunctionType(), callval, args)); - cubcall->setDebugLoc(gutils->getNewFromOriginal(orig->getDebugLoc())); - cubcall->setCallingConv(orig->getCallingConv()); - Value *dif0 = Builder2.CreateFDiv( - Builder2.CreateFMul(diffe(orig, Builder2), cubcall), - Builder2.CreateFMul(ConstantFP::get(x->getType(), 3), x)); - addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType()); - return; + CallInst *cubcall = cast( + Builder2.CreateCall(orig->getFunctionType(), callval, args)); + cubcall->setDebugLoc(gutils->getNewFromOriginal(orig->getDebugLoc())); + cubcall->setCallingConv(orig->getCallingConv()); + Value *dif0 = Builder2.CreateFDiv( + Builder2.CreateFMul(diffe(orig->getArgOperand(0), Builder2), + cubcall), + Builder2.CreateFMul(ConstantFP::get(x->getType(), 3), x)); + setDiffe(orig, dif0, Builder2); + return; + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + + Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), + Builder2); + Value *args[] = {x}; +#if LLVM_VERSION_MAJOR >= 11 + auto callval = orig->getCalledOperand(); +#else + auto callval = orig->getCalledValue(); +#endif + CallInst *cubcall = cast( + Builder2.CreateCall(orig->getFunctionType(), callval, args)); + cubcall->setDebugLoc(gutils->getNewFromOriginal(orig->getDebugLoc())); + cubcall->setCallingConv(orig->getCallingConv()); + Value *dif0 = Builder2.CreateFDiv( + Builder2.CreateFMul(diffe(orig, Builder2), cubcall), + Builder2.CreateFMul(ConstantFP::get(x->getType(), 3), x)); + addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType()); + return; + } + case DerivativeMode::ReverseModePrimal: { + return; + } + } } if (funcName == "tanhf" || funcName == "tanh") { @@ -5764,25 +5838,48 @@ class AdjointGenerator } } eraseIfUnused(*orig); - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(orig)) + if (gutils->isConstantInstruction(orig)) return; - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), - Builder2); - - SmallVector args = {x}; - auto coshf = gutils->oldFunc->getParent()->getOrInsertFunction( - (funcName == "tanh") ? "cosh" : "coshf", called->getFunctionType(), - called->getAttributes()); - auto cal = cast(Builder2.CreateCall(coshf, args)); - Value *dif0 = Builder2.CreateFDiv(diffe(orig, Builder2), - Builder2.CreateFMul(cal, cal)); - setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); - addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType()); - return; + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); + Value *x = gutils->getNewFromOriginal(orig->getArgOperand(0)); + + SmallVector args = {x}; + auto coshf = gutils->oldFunc->getParent()->getOrInsertFunction( + (funcName == "tanh") ? "cosh" : "coshf", + called->getFunctionType(), called->getAttributes()); + auto cal = cast(Builder2.CreateCall(coshf, args)); + Value *dif0 = + Builder2.CreateFDiv(diffe(orig->getArgOperand(0), Builder2), + Builder2.CreateFMul(cal, cal)); + setDiffe(orig, dif0, Builder2); + return; + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), + Builder2); + + SmallVector args = {x}; + auto coshf = gutils->oldFunc->getParent()->getOrInsertFunction( + (funcName == "tanh") ? "cosh" : "coshf", + called->getFunctionType(), called->getAttributes()); + auto cal = cast(Builder2.CreateCall(coshf, args)); + Value *dif0 = Builder2.CreateFDiv(diffe(orig, Builder2), + Builder2.CreateFMul(cal, cal)); + setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); + addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType()); + return; + } + case DerivativeMode::ReverseModePrimal: { + return; + } + } } if (funcName == "coshf" || funcName == "cosh") { @@ -5794,24 +5891,46 @@ class AdjointGenerator } } eraseIfUnused(*orig); - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(orig)) + if (gutils->isConstantInstruction(orig)) return; - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), - Builder2); - - SmallVector args = {x}; - auto sinhf = gutils->oldFunc->getParent()->getOrInsertFunction( - (funcName == "cosh") ? "sinh" : "sinhf", called->getFunctionType(), - called->getAttributes()); - auto cal = cast(Builder2.CreateCall(sinhf, args)); - Value *dif0 = Builder2.CreateFMul(diffe(orig, Builder2), cal); - setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); - addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType()); - return; + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); + Value *x = gutils->getNewFromOriginal(orig->getArgOperand(0)); + + SmallVector args = {x}; + auto sinhf = gutils->oldFunc->getParent()->getOrInsertFunction( + (funcName == "cosh") ? "sinh" : "sinhf", + called->getFunctionType(), called->getAttributes()); + auto cal = cast(Builder2.CreateCall(sinhf, args)); + Value *dif0 = + Builder2.CreateFMul(diffe(orig->getArgOperand(0), Builder2), cal); + setDiffe(orig, dif0, Builder2); + return; + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), + Builder2); + + SmallVector args = {x}; + auto sinhf = gutils->oldFunc->getParent()->getOrInsertFunction( + (funcName == "cosh") ? "sinh" : "sinhf", + called->getFunctionType(), called->getAttributes()); + auto cal = cast(Builder2.CreateCall(sinhf, args)); + Value *dif0 = Builder2.CreateFMul(diffe(orig, Builder2), cal); + setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); + addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType()); + return; + } + case DerivativeMode::ReverseModePrimal: { + return; + } + } } if (funcName == "sinhf" || funcName == "sinh") { if (gutils->knownRecomputeHeuristic.find(orig) != @@ -5822,24 +5941,46 @@ class AdjointGenerator } } eraseIfUnused(*orig); - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(orig)) + if (gutils->isConstantInstruction(orig)) return; - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), - Builder2); - - SmallVector args = {x}; - auto sinhf = gutils->oldFunc->getParent()->getOrInsertFunction( - (funcName == "sinh") ? "cosh" : "coshf", called->getFunctionType(), - called->getAttributes()); - auto cal = cast(Builder2.CreateCall(sinhf, args)); - Value *dif0 = Builder2.CreateFMul(diffe(orig, Builder2), cal); - setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); - addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType()); - return; + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); + Value *x = gutils->getNewFromOriginal(orig->getArgOperand(0)); + + SmallVector args = {x}; + auto sinhf = gutils->oldFunc->getParent()->getOrInsertFunction( + (funcName == "sinh") ? "cosh" : "coshf", + called->getFunctionType(), called->getAttributes()); + auto cal = cast(Builder2.CreateCall(sinhf, args)); + Value *dif0 = + Builder2.CreateFMul(diffe(orig->getArgOperand(0), Builder2), cal); + setDiffe(orig, dif0, Builder2); + return; + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), + Builder2); + + SmallVector args = {x}; + auto sinhf = gutils->oldFunc->getParent()->getOrInsertFunction( + (funcName == "sinh") ? "cosh" : "coshf", + called->getFunctionType(), called->getAttributes()); + auto cal = cast(Builder2.CreateCall(sinhf, args)); + Value *dif0 = Builder2.CreateFMul(diffe(orig, Builder2), cal); + setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); + addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType()); + return; + } + case DerivativeMode::ReverseModePrimal: { + return; + } + } } // Functions that only modify pointers and don't allocate memory, @@ -5887,7 +6028,8 @@ class AdjointGenerator if (subretType == DIFFE_TYPE::DUP_ARG) { Value *shadow = placeholder; if (lrc || Mode == DerivativeMode::ReverseModePrimal || - Mode == DerivativeMode::ReverseModeCombined) { + Mode == DerivativeMode::ReverseModeCombined || + Mode == DerivativeMode::ForwardMode) { if (gutils->isConstantValue(orig->getArgOperand(0))) shadow = gutils->getNewFromOriginal(orig); else { @@ -5925,6 +6067,12 @@ class AdjointGenerator } } + if (Mode == DerivativeMode::ForwardMode) { + eraseIfUnused(*orig); + assert(gutils->isConstantInstruction(orig)); + return; + } + if (!shouldCache && !lrc) { std::map Seen; for (auto pair : gutils->knownRecomputeHeuristic) @@ -5955,29 +6103,56 @@ class AdjointGenerator } } eraseIfUnused(*orig); - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(orig)) + if (gutils->isConstantInstruction(orig)) return; - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), - Builder2); - - Value *sq = Builder2.CreateFNeg(Builder2.CreateFMul(x, x)); - Type *tys[] = {sq->getType()}; - Function *ExpF = Intrinsic::getDeclaration( - gutils->oldFunc->getParent(), Intrinsic::exp, tys); - Value *cal = Builder2.CreateCall(ExpF, std::vector({sq})); - - cal = Builder2.CreateFMul( - cal, ConstantFP::get( - sq->getType(), - 1.1283791670955125738961589031215451716881012586580)); - cal = Builder2.CreateFMul(cal, diffe(orig, Builder2)); - setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); - addToDiffe(orig->getArgOperand(0), cal, Builder2, x->getType()); - return; + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); + Value *x = gutils->getNewFromOriginal(orig->getArgOperand(0)); + + Value *sq = Builder2.CreateFNeg(Builder2.CreateFMul(x, x)); + Type *tys[] = {sq->getType()}; + Function *ExpF = Intrinsic::getDeclaration( + gutils->oldFunc->getParent(), Intrinsic::exp, tys); + Value *cal = Builder2.CreateCall(ExpF, std::vector({sq})); + + cal = Builder2.CreateFMul( + cal, ConstantFP::get( + sq->getType(), + 1.1283791670955125738961589031215451716881012586580)); + cal = Builder2.CreateFMul(cal, + diffe(orig->getArgOperand(0), Builder2)); + setDiffe(orig, cal, Builder2); + return; + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + Value *x = lookup( + gutils->getNewFromOriginal(orig->getArgOperand(0)), Builder2); + + Value *sq = Builder2.CreateFNeg(Builder2.CreateFMul(x, x)); + Type *tys[] = {sq->getType()}; + Function *ExpF = Intrinsic::getDeclaration( + gutils->oldFunc->getParent(), Intrinsic::exp, tys); + Value *cal = Builder2.CreateCall(ExpF, std::vector({sq})); + + cal = Builder2.CreateFMul( + cal, ConstantFP::get( + sq->getType(), + 1.1283791670955125738961589031215451716881012586580)); + cal = Builder2.CreateFMul(cal, diffe(orig, Builder2)); + setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); + addToDiffe(orig->getArgOperand(0), cal, Builder2, x->getType()); + return; + } + case DerivativeMode::ReverseModePrimal: { + return; + } + } } if (funcName == "erfi") { if (gutils->knownRecomputeHeuristic.find(orig) != @@ -5988,29 +6163,55 @@ class AdjointGenerator } } eraseIfUnused(*orig); - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(orig)) + if (gutils->isConstantInstruction(orig)) return; - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), - Builder2); - - Value *sq = Builder2.CreateFMul(x, x); - Type *tys[] = {sq->getType()}; - Function *ExpF = Intrinsic::getDeclaration( - gutils->oldFunc->getParent(), Intrinsic::exp, tys); - Value *cal = Builder2.CreateCall(ExpF, std::vector({sq})); - - cal = Builder2.CreateFMul( - cal, ConstantFP::get( - sq->getType(), - 1.1283791670955125738961589031215451716881012586580)); - cal = Builder2.CreateFMul(cal, diffe(orig, Builder2)); - setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); - addToDiffe(orig->getArgOperand(0), cal, Builder2, x->getType()); - return; + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); + Value *x = gutils->getNewFromOriginal(orig->getArgOperand(0)); + + Value *sq = Builder2.CreateFMul(x, x); + Type *tys[] = {sq->getType()}; + Function *ExpF = Intrinsic::getDeclaration( + gutils->oldFunc->getParent(), Intrinsic::exp, tys); + Value *cal = Builder2.CreateCall(ExpF, std::vector({sq})); + + cal = Builder2.CreateFMul( + cal, ConstantFP::get( + sq->getType(), + 1.1283791670955125738961589031215451716881012586580)); + cal = Builder2.CreateFMul(cal, + diffe(orig->getArgOperand(0), Builder2)); + setDiffe(orig, cal, Builder2); + return; + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + Value *x = lookup( + gutils->getNewFromOriginal(orig->getArgOperand(0)), Builder2); + + Value *sq = Builder2.CreateFMul(x, x); + Type *tys[] = {sq->getType()}; + Function *ExpF = Intrinsic::getDeclaration( + gutils->oldFunc->getParent(), Intrinsic::exp, tys); + Value *cal = Builder2.CreateCall(ExpF, std::vector({sq})); + + cal = Builder2.CreateFMul( + cal, ConstantFP::get( + sq->getType(), + 1.1283791670955125738961589031215451716881012586580)); + cal = Builder2.CreateFMul(cal, diffe(orig, Builder2)); + setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); + addToDiffe(orig->getArgOperand(0), cal, Builder2, x->getType()); + return; + } + case DerivativeMode::ReverseModePrimal: + return; + } } if (funcName == "erfc") { if (gutils->knownRecomputeHeuristic.find(orig) != @@ -6021,29 +6222,54 @@ class AdjointGenerator } } eraseIfUnused(*orig); - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(orig)) + if (gutils->isConstantInstruction(orig)) return; - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), - Builder2); - - Value *sq = Builder2.CreateFNeg(Builder2.CreateFMul(x, x)); - Type *tys[] = {sq->getType()}; - Function *ExpF = Intrinsic::getDeclaration( - gutils->oldFunc->getParent(), Intrinsic::exp, tys); - Value *cal = Builder2.CreateCall(ExpF, std::vector({sq})); - - cal = Builder2.CreateFMul( - cal, ConstantFP::get( - sq->getType(), - -1.1283791670955125738961589031215451716881012586580)); - cal = Builder2.CreateFMul(cal, diffe(orig, Builder2)); - setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); - addToDiffe(orig->getArgOperand(0), cal, Builder2, x->getType()); - return; + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); + Value *x = gutils->getNewFromOriginal(orig->getArgOperand(0)); + Value *sq = Builder2.CreateFNeg(Builder2.CreateFMul(x, x)); + Type *tys[] = {sq->getType()}; + Function *ExpF = Intrinsic::getDeclaration( + gutils->oldFunc->getParent(), Intrinsic::exp, tys); + Value *cal = Builder2.CreateCall(ExpF, std::vector({sq})); + + cal = Builder2.CreateFMul( + cal, ConstantFP::get( + sq->getType(), + -1.1283791670955125738961589031215451716881012586580)); + cal = Builder2.CreateFMul(cal, + diffe(orig->getArgOperand(0), Builder2)); + setDiffe(orig, cal, Builder2); + return; + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + Value *x = lookup( + gutils->getNewFromOriginal(orig->getArgOperand(0)), Builder2); + + Value *sq = Builder2.CreateFNeg(Builder2.CreateFMul(x, x)); + Type *tys[] = {sq->getType()}; + Function *ExpF = Intrinsic::getDeclaration( + gutils->oldFunc->getParent(), Intrinsic::exp, tys); + Value *cal = Builder2.CreateCall(ExpF, std::vector({sq})); + + cal = Builder2.CreateFMul( + cal, ConstantFP::get( + sq->getType(), + -1.1283791670955125738961589031215451716881012586580)); + cal = Builder2.CreateFMul(cal, diffe(orig, Builder2)); + setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); + addToDiffe(orig->getArgOperand(0), cal, Builder2, x->getType()); + } + case DerivativeMode::ReverseModePrimal: { + return; + } + } } if (funcName == "j0" || funcName == "y0" || funcName == "j0f" || @@ -6056,26 +6282,50 @@ class AdjointGenerator } } eraseIfUnused(*orig); - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(orig)) + if (gutils->isConstantInstruction(orig)) return; - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), - Builder2); - - Value *dx = Builder2.CreateCall( - gutils->oldFunc->getParent()->getOrInsertFunction( - (funcName[0] == 'j') ? ((funcName == "j0") ? "j1" : "j1f") - : ((funcName == "y0") ? "y1" : "y1f"), - called->getFunctionType()), - std::vector({x})); - dx = Builder2.CreateFNeg(dx); - dx = Builder2.CreateFMul(dx, diffe(orig, Builder2)); - setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); - addToDiffe(orig->getArgOperand(0), dx, Builder2, x->getType()); - return; + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); + Value *x = gutils->getNewFromOriginal(orig->getArgOperand(0)); + + Value *dx = Builder2.CreateCall( + gutils->oldFunc->getParent()->getOrInsertFunction( + (funcName[0] == 'j') ? ((funcName == "j0") ? "j1" : "j1f") + : ((funcName == "y0") ? "y1" : "y1f"), + called->getFunctionType()), + std::vector({x})); + dx = Builder2.CreateFNeg(dx); + dx = Builder2.CreateFMul(dx, + diffe(orig->getArgOperand(0), Builder2)); + setDiffe(orig, dx, Builder2); + return; + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + Value *x = lookup( + gutils->getNewFromOriginal(orig->getArgOperand(0)), Builder2); + + Value *dx = Builder2.CreateCall( + gutils->oldFunc->getParent()->getOrInsertFunction( + (funcName[0] == 'j') ? ((funcName == "j0") ? "j1" : "j1f") + : ((funcName == "y0") ? "y1" : "y1f"), + called->getFunctionType()), + std::vector({x})); + dx = Builder2.CreateFNeg(dx); + dx = Builder2.CreateFMul(dx, diffe(orig, Builder2)); + setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); + addToDiffe(orig->getArgOperand(0), dx, Builder2, x->getType()); + return; + } + case DerivativeMode::ReverseModePrimal: { + return; + } + } } if (funcName == "j1" || funcName == "y1" || funcName == "j1f" || @@ -6088,38 +6338,74 @@ class AdjointGenerator } } eraseIfUnused(*orig); - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(orig)) + if (gutils->isConstantInstruction(orig)) return; - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), - Builder2); - - Value *d0 = Builder2.CreateCall( - gutils->oldFunc->getParent()->getOrInsertFunction( - (funcName[0] == 'j') ? ((funcName == "j1") ? "j0" : "j0f") - : ((funcName == "y1") ? "y0" : "y0f"), - called->getFunctionType()), - std::vector({x})); - - Type *intType = - Type::getIntNTy(called->getContext(), sizeof(int) * 8); - Type *pargs[] = {intType, x->getType()}; - auto FT2 = FunctionType::get(x->getType(), pargs, false); - Value *d2 = Builder2.CreateCall( - gutils->oldFunc->getParent()->getOrInsertFunction( - (funcName[0] == 'j') ? ((funcName == "j1") ? "jn" : "jnf") - : ((funcName == "y1") ? "yn" : "ynf"), - FT2), - std::vector({ConstantInt::get(intType, 2), x})); - Value *dx = Builder2.CreateFSub(d0, d2); - dx = Builder2.CreateFMul(dx, ConstantFP::get(x->getType(), 0.5)); - dx = Builder2.CreateFMul(dx, diffe(orig, Builder2)); - setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); - addToDiffe(orig->getArgOperand(0), dx, Builder2, x->getType()); - return; + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); + Value *x = gutils->getNewFromOriginal(orig->getArgOperand(0)); + + Value *d0 = Builder2.CreateCall( + gutils->oldFunc->getParent()->getOrInsertFunction( + (funcName[0] == 'j') ? ((funcName == "j1") ? "j0" : "j0f") + : ((funcName == "y1") ? "y0" : "y0f"), + called->getFunctionType()), + std::vector({x})); + + Type *intType = + Type::getIntNTy(called->getContext(), sizeof(int) * 8); + Type *pargs[] = {intType, x->getType()}; + auto FT2 = FunctionType::get(x->getType(), pargs, false); + Value *d2 = Builder2.CreateCall( + gutils->oldFunc->getParent()->getOrInsertFunction( + (funcName[0] == 'j') ? ((funcName == "j1") ? "jn" : "jnf") + : ((funcName == "y1") ? "yn" : "ynf"), + FT2), + std::vector({ConstantInt::get(intType, 2), x})); + Value *dx = Builder2.CreateFSub(d0, d2); + dx = Builder2.CreateFMul(dx, ConstantFP::get(x->getType(), 0.5)); + dx = Builder2.CreateFMul(dx, + diffe(orig->getArgOperand(0), Builder2)); + setDiffe(orig, dx, Builder2); + return; + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + Value *x = lookup( + gutils->getNewFromOriginal(orig->getArgOperand(0)), Builder2); + + Value *d0 = Builder2.CreateCall( + gutils->oldFunc->getParent()->getOrInsertFunction( + (funcName[0] == 'j') ? ((funcName == "j1") ? "j0" : "j0f") + : ((funcName == "y1") ? "y0" : "y0f"), + called->getFunctionType()), + std::vector({x})); + + Type *intType = + Type::getIntNTy(called->getContext(), sizeof(int) * 8); + Type *pargs[] = {intType, x->getType()}; + auto FT2 = FunctionType::get(x->getType(), pargs, false); + Value *d2 = Builder2.CreateCall( + gutils->oldFunc->getParent()->getOrInsertFunction( + (funcName[0] == 'j') ? ((funcName == "j1") ? "jn" : "jnf") + : ((funcName == "y1") ? "yn" : "ynf"), + FT2), + std::vector({ConstantInt::get(intType, 2), x})); + Value *dx = Builder2.CreateFSub(d0, d2); + dx = Builder2.CreateFMul(dx, ConstantFP::get(x->getType(), 0.5)); + dx = Builder2.CreateFMul(dx, diffe(orig, Builder2)); + setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); + addToDiffe(orig->getArgOperand(0), dx, Builder2, x->getType()); + return; + } + case DerivativeMode::ReverseModePrimal: { + return; + } + } } if (funcName == "jn" || funcName == "yn" || funcName == "jnf" || @@ -6132,35 +6418,67 @@ class AdjointGenerator } } eraseIfUnused(*orig); - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(orig)) + if (gutils->isConstantInstruction(orig)) return; - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(1)), - Builder2); - Value *n = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), - Builder2); - - Value *d0 = Builder2.CreateCall( - called, - std::vector( - {Builder2.CreateSub(n, ConstantInt::get(n->getType(), 1)), - x})); - - Value *d2 = Builder2.CreateCall( - called, - std::vector( - {Builder2.CreateAdd(n, ConstantInt::get(n->getType(), 1)), - x})); - - Value *dx = Builder2.CreateFSub(d0, d2); - dx = Builder2.CreateFMul(dx, ConstantFP::get(x->getType(), 0.5)); - dx = Builder2.CreateFMul(dx, diffe(orig, Builder2)); - setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); - addToDiffe(orig->getArgOperand(1), dx, Builder2, x->getType()); - return; + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); + Value *x = gutils->getNewFromOriginal(orig->getArgOperand(1)); + Value *n = gutils->getNewFromOriginal(orig->getArgOperand(0)); + + Value *d0 = Builder2.CreateCall( + called, + std::vector( + {Builder2.CreateSub(n, ConstantInt::get(n->getType(), 1)), + x})); + + Value *d2 = Builder2.CreateCall( + called, + std::vector( + {Builder2.CreateAdd(n, ConstantInt::get(n->getType(), 1)), + x})); + + Value *dx = Builder2.CreateFSub(d0, d2); + dx = Builder2.CreateFMul(dx, ConstantFP::get(x->getType(), 0.5)); + dx = Builder2.CreateFMul(dx, + diffe(orig->getArgOperand(1), Builder2)); + setDiffe(orig, dx, Builder2); + return; + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + Value *x = lookup( + gutils->getNewFromOriginal(orig->getArgOperand(1)), Builder2); + Value *n = lookup( + gutils->getNewFromOriginal(orig->getArgOperand(0)), Builder2); + + Value *d0 = Builder2.CreateCall( + called, + std::vector( + {Builder2.CreateSub(n, ConstantInt::get(n->getType(), 1)), + x})); + + Value *d2 = Builder2.CreateCall( + called, + std::vector( + {Builder2.CreateAdd(n, ConstantInt::get(n->getType(), 1)), + x})); + + Value *dx = Builder2.CreateFSub(d0, d2); + dx = Builder2.CreateFMul(dx, ConstantFP::get(x->getType(), 0.5)); + dx = Builder2.CreateFMul(dx, diffe(orig, Builder2)); + setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); + addToDiffe(orig->getArgOperand(1), dx, Builder2, x->getType()); + return; + } + case DerivativeMode::ReverseModePrimal: { + return; + } + } } if (funcName == "julia.write_barrier") { @@ -6208,38 +6526,71 @@ class AdjointGenerator } } eraseIfUnused(*orig); - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(orig)) { + if (gutils->isConstantInstruction(orig)) { return; } - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - - Value *vdiff = diffe(orig, Builder2); - Value *x = lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), - Builder2); - - Value *args[] = {x}; - - Type *tys[] = {orig->getOperand(0)->getType()}; - CallInst *dsin = cast(Builder2.CreateCall( - Intrinsic::getDeclaration(gutils->oldFunc->getParent(), - Intrinsic::cos, tys), - args)); - CallInst *dcos = cast(Builder2.CreateCall( - Intrinsic::getDeclaration(gutils->oldFunc->getParent(), - Intrinsic::sin, tys), - args)); - Value *dif0 = Builder2.CreateFSub( - Builder2.CreateFMul(Builder2.CreateExtractValue(vdiff, {0}), - dsin), - Builder2.CreateFMul(Builder2.CreateExtractValue(vdiff, {1}), - dcos)); + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); + + Value *vdiff = diffe(orig->getArgOperand(0), Builder2); + Value *x = gutils->getNewFromOriginal(orig->getArgOperand(0)); + Value *args[] = {x}; + + Type *tys[] = {orig->getOperand(0)->getType()}; + CallInst *dsin = cast(Builder2.CreateCall( + Intrinsic::getDeclaration(gutils->oldFunc->getParent(), + Intrinsic::cos, tys), + args)); + CallInst *dcos = cast(Builder2.CreateCall( + Intrinsic::getDeclaration(gutils->oldFunc->getParent(), + Intrinsic::sin, tys), + args)); + Value *dif0 = Builder2.CreateFSub( + Builder2.CreateFMul(Builder2.CreateExtractValue(vdiff, {0}), + dsin), + Builder2.CreateFMul(Builder2.CreateExtractValue(vdiff, {1}), + dcos)); + + setDiffe(orig, dif0, Builder2); + return; + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); - setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); - addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType()); - return; + Value *vdiff = diffe(orig, Builder2); + Value *x = lookup( + gutils->getNewFromOriginal(orig->getArgOperand(0)), Builder2); + + Value *args[] = {x}; + + Type *tys[] = {orig->getOperand(0)->getType()}; + CallInst *dsin = cast(Builder2.CreateCall( + Intrinsic::getDeclaration(gutils->oldFunc->getParent(), + Intrinsic::cos, tys), + args)); + CallInst *dcos = cast(Builder2.CreateCall( + Intrinsic::getDeclaration(gutils->oldFunc->getParent(), + Intrinsic::sin, tys), + args)); + Value *dif0 = Builder2.CreateFSub( + Builder2.CreateFMul(Builder2.CreateExtractValue(vdiff, {0}), + dsin), + Builder2.CreateFMul(Builder2.CreateExtractValue(vdiff, {1}), + dcos)); + + setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); + addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType()); + return; + } + case DerivativeMode::ReverseModePrimal: { + return; + } + } } if (funcName == "cabs" || funcName == "cabsf" || funcName == "cabsl") { if (gutils->knownRecomputeHeuristic.find(orig) != @@ -6250,37 +6601,72 @@ class AdjointGenerator } } eraseIfUnused(*orig); - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(orig)) { + if (gutils->isConstantInstruction(orig)) { return; } - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); - Value *vdiff = diffe(orig, Builder2); + SmallVector args; + for (size_t i = 0; i < orig->getNumArgOperands(); ++i) + args.push_back( + gutils->getNewFromOriginal(orig->getArgOperand(i))); - SmallVector args; - for (size_t i = 0; i < orig->getNumArgOperands(); ++i) - args.push_back(lookup( - gutils->getNewFromOriginal(orig->getArgOperand(i)), Builder2)); + CallInst *d = cast(Builder2.CreateCall(called, args)); - CallInst *d = cast(Builder2.CreateCall(called, args)); + if (args.size() == 2) { + Value *dif1 = Builder2.CreateFMul( + args[0], Builder2.CreateFDiv( + diffe(orig->getArgOperand(0), Builder2), d)); - Value *div = Builder2.CreateFDiv(vdiff, d); + Value *dif2 = Builder2.CreateFMul( + args[1], Builder2.CreateFDiv( + diffe(orig->getArgOperand(1), Builder2), d)); - setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); + setDiffe(orig, Builder2.CreateFAdd(dif1, dif2), Builder2); + return; + } else { + llvm::errs() << *orig << "\n"; + llvm_unreachable("unknown calling convention found for cabs"); + } + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + + Value *vdiff = diffe(orig, Builder2); + + SmallVector args; + for (size_t i = 0; i < orig->getNumArgOperands(); ++i) + args.push_back( + lookup(gutils->getNewFromOriginal(orig->getArgOperand(i)), + Builder2)); + + CallInst *d = cast(Builder2.CreateCall(called, args)); - if (args.size() == 2) { - for (int i = 0; i < 2; i++) - if (!gutils->isConstantValue(orig->getArgOperand(i))) - addToDiffe(orig->getArgOperand(i), - Builder2.CreateFMul(args[i], div), Builder2, - orig->getType()); + Value *div = Builder2.CreateFDiv(vdiff, d); + + setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); + + if (args.size() == 2) { + for (int i = 0; i < 2; i++) + if (!gutils->isConstantValue(orig->getArgOperand(i))) + addToDiffe(orig->getArgOperand(i), + Builder2.CreateFMul(args[i], div), Builder2, + orig->getType()); + return; + } else { + llvm::errs() << *orig << "\n"; + llvm_unreachable("unknown calling convention found for cabs"); + } + } + case DerivativeMode::ReverseModePrimal: { return; - } else { - llvm::errs() << *orig << "\n"; - llvm_unreachable("unknown calling convention found for cabs"); + } } } if (funcName == "ldexp" || funcName == "ldexpf" || @@ -6293,24 +6679,45 @@ class AdjointGenerator } } eraseIfUnused(*orig); - if (Mode == DerivativeMode::ReverseModePrimal || - gutils->isConstantInstruction(orig)) { + if (gutils->isConstantInstruction(orig)) { return; } - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); + switch (Mode) { + case DerivativeMode::ForwardMode: { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); - Value *vdiff = diffe(orig, Builder2); - Value *exponent = lookup( - gutils->getNewFromOriginal(orig->getArgOperand(1)), Builder2); + Value *vdiff = diffe(orig->getArgOperand(0), Builder2); + Value *exponent = + gutils->getNewFromOriginal(orig->getArgOperand(1)); - Value *args[] = {vdiff, exponent}; + Value *args[] = {vdiff, exponent}; - CallInst *darg = cast(Builder2.CreateCall(called, args)); - setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); - addToDiffe(orig->getArgOperand(0), darg, Builder2, orig->getType()); - return; + CallInst *darg = cast(Builder2.CreateCall(called, args)); + setDiffe(orig, darg, Builder2); + return; + } + case DerivativeMode::ReverseModeGradient: + case DerivativeMode::ReverseModeCombined: { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + + Value *vdiff = diffe(orig, Builder2); + Value *exponent = lookup( + gutils->getNewFromOriginal(orig->getArgOperand(1)), Builder2); + + Value *args[] = {vdiff, exponent}; + + CallInst *darg = cast(Builder2.CreateCall(called, args)); + setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); + addToDiffe(orig->getArgOperand(0), darg, Builder2, orig->getType()); + return; + } + case DerivativeMode::ReverseModePrimal: { + return; + } + } } } @@ -6338,24 +6745,37 @@ class AdjointGenerator bool constval = gutils->isConstantValue(orig); if (!constval) { - auto anti = - gutils->createAntiMalloc(orig, getIndex(orig, CacheType::Shadow)); if (Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ReverseModeGradient) { - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - Value *tofree = lookup(anti, Builder2); - assert(tofree); - assert(tofree->getType()); - assert(Type::getInt8Ty(tofree->getContext())); - assert(PointerType::getUnqual(Type::getInt8Ty(tofree->getContext()))); - assert(Type::getInt8PtrTy(tofree->getContext())); - auto dbgLoc = gutils->getNewFromOriginal(orig)->getDebugLoc(); - auto CI = freeKnownAllocation(Builder2, tofree, *called, dbgLoc, - gutils->TLI); - if (CI) { - CI->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull); + Mode == DerivativeMode::ReverseModeGradient || + Mode == DerivativeMode::ReverseModePrimal) { + auto anti = + gutils->createAntiMalloc(orig, getIndex(orig, CacheType::Shadow)); + if (Mode == DerivativeMode::ReverseModeCombined || + Mode == DerivativeMode::ReverseModeGradient) { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + Value *tofree = lookup(anti, Builder2); + assert(tofree); + assert(tofree->getType()); + assert(Type::getInt8Ty(tofree->getContext())); + assert( + PointerType::getUnqual(Type::getInt8Ty(tofree->getContext()))); + assert(Type::getInt8PtrTy(tofree->getContext())); + auto dbgLoc = gutils->getNewFromOriginal(orig)->getDebugLoc(); + auto CI = freeKnownAllocation(Builder2, tofree, *called, dbgLoc, + gutils->TLI); + if (CI) + CI->addAttribute(AttributeList::FirstArgIndex, + Attribute::NonNull); } + } else if (Mode == DerivativeMode::ForwardMode) { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); + SmallVector args = {orig->getArgOperand(0)}; + CallInst *CI = Builder2.CreateCall(orig->getFunctionType(), + orig->getCalledFunction(), args); + CI->setAttributes(orig->getAttributes()); + return; } } @@ -6412,7 +6832,8 @@ class AdjointGenerator // TODO enable this if we need to free the memory // NOTE THAT TOPLEVEL IS THERE SIMPLY BECAUSE THAT WAS PREVIOUS ATTITUTE // TO FREE'ing - if (Mode != DerivativeMode::ReverseModeCombined) { + if (Mode == DerivativeMode::ReverseModeGradient || + Mode == DerivativeMode::ReverseModePrimal) { if ((primalNeededInReverse && !gutils->unnecessaryIntermediates.count(orig)) || hasPDFree) { @@ -6425,7 +6846,8 @@ class AdjointGenerator freeKnownAllocation(Builder2, lookup(nop, Builder2), *called, dbgLoc, gutils->TLI); } - } else if (Mode != DerivativeMode::ReverseModePrimal) { + } else if (Mode == DerivativeMode::ReverseModeGradient || + Mode == DerivativeMode::ReverseModeCombined) { // Note that here we cannot simply replace with null as users who // try to find the shadow pointer will use the shadow of null rather // than the true shadow of this @@ -6435,7 +6857,7 @@ class AdjointGenerator gutils->replaceAWithB(newCall, pn); gutils->erase(newCall); } - } else { + } else if (Mode == DerivativeMode::ReverseModeCombined) { IRBuilder<> Builder2(call.getParent()); getReverseBuilder(Builder2); auto dbgLoc = gutils->getNewFromOriginal(orig)->getDebugLoc(); @@ -6485,7 +6907,8 @@ class AdjointGenerator if (!constval) { Value *val; if (Mode == DerivativeMode::ReverseModePrimal || - Mode == DerivativeMode::ReverseModeCombined) { + Mode == DerivativeMode::ReverseModeCombined || + Mode == DerivativeMode::ForwardMode) { Value *ptrshadow = gutils->invertPointerM(call.getArgOperand(0), BuilderZ); BuilderZ.CreateCall( @@ -6523,7 +6946,7 @@ class AdjointGenerator // memset->addParamAttr(0, Attribute::getWithAlignment(Context, // inst->getAlignment())); memset->addParamAttr(0, Attribute::NonNull); - } else { + } else if (Mode == DerivativeMode::ReverseModeGradient) { PHINode *toReplace = BuilderZ.CreatePHI( cast(call.getArgOperand(0)->getType()) ->getElementType(), @@ -6549,7 +6972,7 @@ class AdjointGenerator // TO FREE'ing if (Mode == DerivativeMode::ReverseModeGradient) { eraseIfUnused(*orig, /*erase*/ true, /*check*/ false); - } else if (Mode != DerivativeMode::ReverseModeCombined) { + } else if (Mode == DerivativeMode::ReverseModePrimal) { // if (is_value_needed_in_reverse( // TR, gutils, orig, /*topLevel*/ Mode == // DerivativeMode::Both)) @@ -6562,7 +6985,7 @@ class AdjointGenerator // to find the shadow pointer will use the shadow of null rather than // the true shadow of this //} - } else { + } else if (Mode == DerivativeMode::ReverseModeCombined) { IRBuilder<> Builder2(newCall->getNextNode()); auto load = Builder2.CreateLoad( gutils->getNewFromOriginal(call.getOperand(0)), "posix_preread"); @@ -6584,6 +7007,22 @@ class AdjointGenerator assert(gutils->invertedPointers.find(orig) == gutils->invertedPointers.end()); + if (Mode == DerivativeMode::ForwardMode) { + if (!gutils->isConstantValue(orig->getArgOperand(0))) { + IRBuilder<> Builder2(&call); + getForwardBuilder(Builder2); + auto origfree = orig->getArgOperand(0); + auto tofree = gutils->invertPointerM(origfree, Builder2); + if (tofree != origfree) { + SmallVector args = {tofree}; + CallInst *CI = Builder2.CreateCall(orig->getFunctionType(), + orig->getCalledFunction(), args); + CI->setAttributes(orig->getAttributes()); + } + } + return; + } + if (gutils->forwardDeallocations.count(orig)) { if (Mode == DerivativeMode::ReverseModeGradient) { eraseIfUnused(*orig, /*erase*/ true, /*check*/ false); diff --git a/enzyme/test/Enzyme/ForwardMode/cosh.ll b/enzyme/test/Enzyme/ForwardMode/cosh.ll new file mode 100644 index 0000000000000..d24bc8e3494ff --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/cosh.ll @@ -0,0 +1,28 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @cosh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @cosh(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @sinh(double %x) +; CHECK-NEXT: %1 = fmul fast double %"x'", %0 +; CHECK-NEXT: %2 = insertvalue { double } undef, double %1, 0 +; CHECK-NEXT: ret { double } %2 +; CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/Enzyme/ForwardMode/erf.ll b/enzyme/test/Enzyme/ForwardMode/erf.ll new file mode 100644 index 0000000000000..d57648da3772e --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/erf.ll @@ -0,0 +1,29 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +declare double @erf(double) + +define double @tester(double %x) { +entry: + %call = call double @erf(double %x) + ret double %call +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %"x'") { +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast double %x, %x +; CHECK-NEXT: %1 = {{(fsub fast double \-?0.000000e\+00,|fneg fast double)}} %0 +; CHECK-NEXT: %2 = call fast double @llvm.exp.f64(double %1) +; CHECK-NEXT: %3 = fmul fast double %2, 0x3FF20DD750429B6D +; CHECK-NEXT: %4 = fmul fast double %3, %"x'" +; CHECK-NEXT: %5 = insertvalue { double } undef, double %4, 0 +; CHECK-NEXT: ret { double } %5 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardMode/erfc.ll b/enzyme/test/Enzyme/ForwardMode/erfc.ll new file mode 100644 index 0000000000000..6eff5249279f5 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/erfc.ll @@ -0,0 +1,29 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +declare double @erfc(double) + +define double @tester(double %x) { +entry: + %call = call double @erfc(double %x) + ret double %call +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %"x'") { +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast double %x, %x +; CHECK-NEXT: %1 = {{(fsub fast double \-?0.000000e\+00,|fneg fast double)}} %0 +; CHECK-NEXT: %2 = call fast double @llvm.exp.f64(double %1) +; CHECK-NEXT: %3 = fmul fast double %2, 0xBFF20DD750429B6D +; CHECK-NEXT: %4 = fmul fast double %3, %"x'" +; CHECK-NEXT: %5 = insertvalue { double } undef, double %4, 0 +; CHECK-NEXT: ret { double } %5 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardMode/erfi.ll b/enzyme/test/Enzyme/ForwardMode/erfi.ll new file mode 100644 index 0000000000000..1c750fa068eff --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/erfi.ll @@ -0,0 +1,28 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s + +declare double @erfi(double) + +define double @tester(double %x) { +entry: + %call = call double @erfi(double %x) + ret double %call +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %"x'") { +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast double %x, %x +; CHECK-NEXT: %1 = call fast double @llvm.exp.f64(double %0) +; CHECK-NEXT: %2 = fmul fast double %1, 0x3FF20DD750429B6D +; CHECK-NEXT: %3 = fmul fast double %2, %"x'" +; CHECK-NEXT: %4 = insertvalue { double } undef, double %3, 0 +; CHECK-NEXT: ret { double } %4 +; CHECK-NEXT: }