diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 8a914775d438f..7604252578c10 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -1031,6 +1031,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) { // abide by those rules if (!isCertainPrintMallocOrFree(called) && called->empty() && !hasMetadata(called, "enzyme_gradient") && + !hasMetadata(called, "enzyme_derivative") && !isa(op) && EnzymeEmptyFnInactive) { InsertConstantValue(TR, Val); insertConstantsFrom(TR, *UpHypothesis); @@ -1666,8 +1667,9 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults &TR, // If requesting empty unknown functions to be considered inactive, abide // by those rules if (!isCertainPrintMallocOrFree(called) && called->empty() && - !hasMetadata(called, "enzyme_gradient") && !isa(op) && - EnzymeEmptyFnInactive) { + !hasMetadata(called, "enzyme_gradient") && + !hasMetadata(called, "enzyme_derivative") && + !isa(op) && EnzymeEmptyFnInactive) { if (EnzymePrintActivity) llvm::errs() << "constant(" << (int)directions << ") up-emptyconst " << *inst << "\n"; diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index f611bf1aff970..5598c312a9db8 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -79,6 +79,220 @@ llvm::cl::opt namespace { +template +static void +handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g, + std::vector &globalsToErase) { + if (g.hasInitializer()) { + if (auto CA = dyn_cast(g.getInitializer())) { + if (CA->getNumOperands() != numargs) { + llvm::errs() << M << "\n"; + llvm::errs() << "Use of " << handlername + << " must be a " + "constant of size " + << numargs << " " << g << "\n"; + llvm_unreachable(handlername); + } else { + Function *Fs[numargs]; + for (size_t i = 0; i < numargs; i++) { + Value *V = CA->getOperand(i); + while (auto CE = dyn_cast(V)) { + V = CE->getOperand(0); + } + if (auto CA = dyn_cast(V)) + V = CA->getOperand(0); + while (auto CE = dyn_cast(V)) { + V = CE->getOperand(0); + } + if (auto F = dyn_cast(V)) { + Fs[i] = F; + } else { + llvm::errs() << M << "\n"; + llvm::errs() << "Param of " << handlername + << " must be a " + "function" + << g << "\n" + << *V << "\n"; + llvm_unreachable(handlername); + } + } + + if (numargs == 3) { + Fs[0]->setMetadata( + "enzyme_augment", + llvm::MDTuple::get(Fs[0]->getContext(), + {llvm::ValueAsMetadata::get(Fs[1])})); + Fs[0]->setMetadata( + "enzyme_gradient", + llvm::MDTuple::get(Fs[0]->getContext(), + {llvm::ValueAsMetadata::get(Fs[2])})); + } else if (numargs == 2) { + Fs[0]->setMetadata( + "enzyme_derivative", + llvm::MDTuple::get(Fs[0]->getContext(), + {llvm::ValueAsMetadata::get(Fs[1])})); + } + } + } else { + llvm::errs() << M << "\n"; + llvm::errs() << "Use of " << handlername + << " must be a " + "constant aggregate " + << g << "\n"; + llvm_unreachable(handlername); + } + } else { + llvm::errs() << M << "\n"; + llvm::errs() << "Use of " << handlername + << " must be a " + "constant array of size " + << numargs << " " << g << "\n"; + llvm_unreachable(handlername); + } + globalsToErase.push_back(&g); +} + +static void +handleInactiveFunction(llvm::Module &M, llvm::GlobalVariable &g, + std::vector &globalsToErase) { + if (g.hasInitializer()) { + Value *V = g.getInitializer(); + while (auto CE = dyn_cast(V)) { + V = CE->getOperand(0); + } + if (auto CA = dyn_cast(V)) + V = CA->getOperand(0); + while (auto CE = dyn_cast(V)) { + V = CE->getOperand(0); + } + if (auto F = dyn_cast(V)) { + F->addAttribute(AttributeList::FunctionIndex, + Attribute::get(g.getContext(), "enzyme_inactive")); + } else { + llvm::errs() << M << "\n"; + llvm::errs() << "Param of __enzyme_inactivefn must be a " + "function" + << g << "\n" + << *V << "\n"; + llvm_unreachable("__enzyme_inactivefn"); + } + } else { + llvm::errs() << M << "\n"; + llvm::errs() << "Use of __enzyme_inactivefn must be a " + "constant function " + << g << "\n"; + llvm_unreachable("__enzyme_register_gradient"); + } + globalsToErase.push_back(&g); +} + +static void handleKnownFunctions(llvm::Function &F) { + if (F.getName() == "MPI_Irecv") { + F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly); + F.addFnAttr(Attribute::NoUnwind); + F.addFnAttr(Attribute::NoRecurse); +#if LLVM_VERSION_MAJOR >= 9 + F.addFnAttr(Attribute::WillReturn); + F.addFnAttr(Attribute::NoFree); + F.addFnAttr(Attribute::NoSync); +#endif + F.addParamAttr(0, Attribute::WriteOnly); + if (F.getFunctionType()->getParamType(2)->isPointerTy()) { + F.addParamAttr(2, Attribute::NoCapture); + F.addParamAttr(2, Attribute::WriteOnly); + } + F.addParamAttr(6, Attribute::WriteOnly); + } + if (F.getName() == "MPI_Isend") { + F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly); + F.addFnAttr(Attribute::NoUnwind); + F.addFnAttr(Attribute::NoRecurse); +#if LLVM_VERSION_MAJOR >= 9 + F.addFnAttr(Attribute::WillReturn); + F.addFnAttr(Attribute::NoFree); + F.addFnAttr(Attribute::NoSync); +#endif + F.addParamAttr(0, Attribute::WriteOnly); + if (F.getFunctionType()->getParamType(2)->isPointerTy()) { + F.addParamAttr(2, Attribute::NoCapture); + F.addParamAttr(2, Attribute::ReadOnly); + } + F.addParamAttr(6, Attribute::WriteOnly); + } + if (F.getName() == "MPI_Comm_rank") { + F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly); + F.addFnAttr(Attribute::NoUnwind); + F.addFnAttr(Attribute::NoRecurse); +#if LLVM_VERSION_MAJOR >= 9 + F.addFnAttr(Attribute::WillReturn); + F.addFnAttr(Attribute::NoFree); + F.addFnAttr(Attribute::NoSync); +#endif + if (F.getFunctionType()->getParamType(0)->isPointerTy()) { + F.addParamAttr(0, Attribute::NoCapture); + F.addParamAttr(0, Attribute::ReadOnly); + } + F.addParamAttr(1, Attribute::WriteOnly); + F.addParamAttr(1, Attribute::NoCapture); + } + if (F.getName() == "MPI_Wait") { + F.addFnAttr(Attribute::NoUnwind); + F.addFnAttr(Attribute::NoRecurse); +#if LLVM_VERSION_MAJOR >= 9 + F.addFnAttr(Attribute::WillReturn); + F.addFnAttr(Attribute::NoFree); + F.addFnAttr(Attribute::NoSync); +#endif + F.addParamAttr(0, Attribute::ReadOnly); + F.addParamAttr(0, Attribute::NoCapture); + F.addParamAttr(1, Attribute::WriteOnly); + F.addParamAttr(1, Attribute::NoCapture); + } + if (F.getName() == "MPI_Waitall") { + F.addFnAttr(Attribute::NoUnwind); + F.addFnAttr(Attribute::NoRecurse); +#if LLVM_VERSION_MAJOR >= 9 + F.addFnAttr(Attribute::WillReturn); + F.addFnAttr(Attribute::NoFree); + F.addFnAttr(Attribute::NoSync); +#endif + F.addParamAttr(1, Attribute::ReadOnly); + F.addParamAttr(1, Attribute::NoCapture); + F.addParamAttr(2, Attribute::WriteOnly); + F.addParamAttr(2, Attribute::NoCapture); + } + if (F.getName() == "omp_get_max_threads" || + F.getName() == "omp_get_thread_num") { + F.addFnAttr(Attribute::ReadOnly); + F.addFnAttr(Attribute::InaccessibleMemOnly); + } + if (F.getName() == "frexp" || F.getName() == "frexpf" || + F.getName() == "frexpl") { + F.addFnAttr(Attribute::ArgMemOnly); + F.addParamAttr(1, Attribute::WriteOnly); + } + if (F.getName() == "__fd_sincos_1" || F.getName() == "__fd_cos_1" || + F.getName() == "__mth_i_ipowi") { + F.addFnAttr(Attribute::ReadNone); + } +} + +static void handleAnnotations(llvm::Function &F) { + if (F.getName().contains("__enzyme_float") || + F.getName().contains("__enzyme_double") || + F.getName().contains("__enzyme_integer") || + F.getName().contains("__enzyme_pointer") || + F.getName().contains("__enzyme_virtualreverse")) { + F.addFnAttr(Attribute::ReadNone); + for (auto &arg : F.args()) { + if (arg.getType()->isPointerTy()) { + arg.addAttr(Attribute::ReadNone); + arg.addAttr(Attribute::NoCapture); + } + } + } +} + class Enzyme : public ModulePass { public: EnzymeLogic Logic; @@ -1416,205 +1630,31 @@ class Enzyme : public ModulePass { } bool runOnModule(Module &M) override { + constexpr static const char gradient_handler_name[] = + "__enzyme_register_gradient"; + constexpr static const char derivative_handler_name[] = + "__enzyme_register_derivative"; + Logic.clear(); bool changed = false; std::vector globalsToErase; for (GlobalVariable &g : M.globals()) { - if (g.getName().contains("__enzyme_register_gradient")) { - if (g.hasInitializer()) { - if (auto CA = dyn_cast(g.getInitializer())) { - if (CA->getNumOperands() != 3) { - llvm::errs() << M << "\n"; - llvm::errs() << "Use of __enzyme_register_gradient must be a " - "constant of size 3 " - << g << "\n"; - llvm_unreachable("__enzyme_register_gradient"); - } else { - Function *Fs[3]; - for (size_t i = 0; i < 3; i++) { - Value *V = CA->getOperand(i); - while (auto CE = dyn_cast(V)) { - V = CE->getOperand(0); - } - if (auto CA = dyn_cast(V)) - V = CA->getOperand(0); - while (auto CE = dyn_cast(V)) { - V = CE->getOperand(0); - } - if (auto F = dyn_cast(V)) { - Fs[i] = F; - } else { - llvm::errs() << M << "\n"; - llvm::errs() - << "Param of __enzyme_register_gradient must be a " - "function" - << g << "\n" - << *V << "\n"; - llvm_unreachable("__enzyme_register_gradient"); - } - } - Fs[0]->setMetadata( - "enzyme_augment", - llvm::MDTuple::get(Fs[0]->getContext(), - {llvm::ValueAsMetadata::get(Fs[1])})); - Fs[0]->setMetadata( - "enzyme_gradient", - llvm::MDTuple::get(Fs[0]->getContext(), - {llvm::ValueAsMetadata::get(Fs[2])})); - } - } else { - llvm::errs() << M << "\n"; - llvm::errs() << "Use of __enzyme_register_gradient must be a " - "constant aggregate " - << g << "\n"; - llvm_unreachable("__enzyme_register_gradient"); - } - } else { - llvm::errs() << M << "\n"; - llvm::errs() << "Use of __enzyme_register_gradient must be a " - "constant array of size 3 " - << g << "\n"; - llvm_unreachable("__enzyme_register_gradient"); - } - globalsToErase.push_back(&g); + if (g.getName().contains(gradient_handler_name)) { + handleCustomDerivative(M, g, globalsToErase); + } else if (g.getName().contains(derivative_handler_name)) { + handleCustomDerivative(M, g, + globalsToErase); } else if (g.getName().contains("__enzyme_inactivefn")) { - if (g.hasInitializer()) { - Value *V = g.getInitializer(); - while (auto CE = dyn_cast(V)) { - V = CE->getOperand(0); - } - if (auto CA = dyn_cast(V)) - V = CA->getOperand(0); - while (auto CE = dyn_cast(V)) { - V = CE->getOperand(0); - } - if (auto F = dyn_cast(V)) { - F->addAttribute(AttributeList::FunctionIndex, - Attribute::get(g.getContext(), "enzyme_inactive")); - } else { - llvm::errs() << M << "\n"; - llvm::errs() << "Param of __enzyme_inactivefn must be a " - "function" - << g << "\n" - << *V << "\n"; - llvm_unreachable("__enzyme_inactivefn"); - } - } else { - llvm::errs() << M << "\n"; - llvm::errs() << "Use of __enzyme_inactivefn must be a " - "constant function " - << g << "\n"; - llvm_unreachable("__enzyme_register_gradient"); - } - globalsToErase.push_back(&g); + handleInactiveFunction(M, g, globalsToErase); } } for (auto g : globalsToErase) { g->eraseFromParent(); } for (Function &F : M) { - if (F.getName().contains("__enzyme_float") || - F.getName().contains("__enzyme_double") || - F.getName().contains("__enzyme_integer") || - F.getName().contains("__enzyme_pointer") || - F.getName().contains("__enzyme_virtualreverse")) { - F.addFnAttr(Attribute::ReadNone); - for (auto &arg : F.args()) { - if (arg.getType()->isPointerTy()) { - arg.addAttr(Attribute::ReadNone); - arg.addAttr(Attribute::NoCapture); - } - } - } - if (F.getName() == "MPI_Irecv") { - F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly); - F.addFnAttr(Attribute::NoUnwind); - F.addFnAttr(Attribute::NoRecurse); -#if LLVM_VERSION_MAJOR >= 9 - F.addFnAttr(Attribute::WillReturn); - F.addFnAttr(Attribute::NoFree); - F.addFnAttr(Attribute::NoSync); -#endif - F.addParamAttr(0, Attribute::WriteOnly); - if (F.getFunctionType()->getParamType(2)->isPointerTy()) { - F.addParamAttr(2, Attribute::NoCapture); - F.addParamAttr(2, Attribute::WriteOnly); - } - F.addParamAttr(6, Attribute::WriteOnly); - } - if (F.getName() == "MPI_Isend") { - F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly); - F.addFnAttr(Attribute::NoUnwind); - F.addFnAttr(Attribute::NoRecurse); -#if LLVM_VERSION_MAJOR >= 9 - F.addFnAttr(Attribute::WillReturn); - F.addFnAttr(Attribute::NoFree); - F.addFnAttr(Attribute::NoSync); -#endif - F.addParamAttr(0, Attribute::WriteOnly); - if (F.getFunctionType()->getParamType(2)->isPointerTy()) { - F.addParamAttr(2, Attribute::NoCapture); - F.addParamAttr(2, Attribute::ReadOnly); - } - F.addParamAttr(6, Attribute::WriteOnly); - } - if (F.getName() == "MPI_Comm_rank") { - F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly); - F.addFnAttr(Attribute::NoUnwind); - F.addFnAttr(Attribute::NoRecurse); -#if LLVM_VERSION_MAJOR >= 9 - F.addFnAttr(Attribute::WillReturn); - F.addFnAttr(Attribute::NoFree); - F.addFnAttr(Attribute::NoSync); -#endif - if (F.getFunctionType()->getParamType(0)->isPointerTy()) { - F.addParamAttr(0, Attribute::NoCapture); - F.addParamAttr(0, Attribute::ReadOnly); - } - F.addParamAttr(1, Attribute::WriteOnly); - F.addParamAttr(1, Attribute::NoCapture); - } - if (F.getName() == "MPI_Wait") { - F.addFnAttr(Attribute::NoUnwind); - F.addFnAttr(Attribute::NoRecurse); -#if LLVM_VERSION_MAJOR >= 9 - F.addFnAttr(Attribute::WillReturn); - F.addFnAttr(Attribute::NoFree); - F.addFnAttr(Attribute::NoSync); -#endif - F.addParamAttr(0, Attribute::ReadOnly); - F.addParamAttr(0, Attribute::NoCapture); - F.addParamAttr(1, Attribute::WriteOnly); - F.addParamAttr(1, Attribute::NoCapture); - } - if (F.getName() == "MPI_Waitall") { - F.addFnAttr(Attribute::NoUnwind); - F.addFnAttr(Attribute::NoRecurse); -#if LLVM_VERSION_MAJOR >= 9 - F.addFnAttr(Attribute::WillReturn); - F.addFnAttr(Attribute::NoFree); - F.addFnAttr(Attribute::NoSync); -#endif - F.addParamAttr(1, Attribute::ReadOnly); - F.addParamAttr(1, Attribute::NoCapture); - F.addParamAttr(2, Attribute::WriteOnly); - F.addParamAttr(2, Attribute::NoCapture); - } - if (F.getName() == "omp_get_max_threads" || - F.getName() == "omp_get_thread_num") { - F.addFnAttr(Attribute::ReadOnly); - F.addFnAttr(Attribute::InaccessibleMemOnly); - } - if (F.getName() == "frexp" || F.getName() == "frexpf" || - F.getName() == "frexpl") { - F.addFnAttr(Attribute::ArgMemOnly); - F.addParamAttr(1, Attribute::WriteOnly); - } - if (F.getName() == "__fd_sincos_1" || F.getName() == "__fd_cos_1" || - F.getName() == "__mth_i_ipowi") { - F.addFnAttr(Attribute::ReadNone); - } + handleAnnotations(F); + handleKnownFunctions(F); if (F.empty()) continue; std::vector toErase; diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index ae901c99855d1..97b1d4c1e2940 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -3681,6 +3681,22 @@ Function *EnzymeLogic::CreateForwardDiff( assert(!todiff->empty()); + if (hasMetadata(todiff, "enzyme_derivative")) { + auto md = todiff->getMetadata("enzyme_derivative"); + if (!isa(md)) { + llvm::errs() << *todiff << "\n"; + llvm::errs() << *md << "\n"; + report_fatal_error( + "unknown derivative for function -- metadata incorrect"); + } + auto md2 = cast(md); + assert(md2->getNumOperands() == 1); + auto gvemd = cast(md2->getOperand(0)); + auto foundcalled = cast(gvemd->getValue()); + + return foundcalled; + } + auto TRo = TA.analyzeFunction(oldTypeInfo); bool retActive = TRo.getReturnAnalysis().Inner0().isPossibleFloat() && !todiff->getReturnType()->isVoidTy(); diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 4604bee9dc253..5d1182f4d8e16 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -4016,7 +4016,8 @@ void TypeAnalyzer::visitCallInst(CallInst &call) { updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1), &call); } - if (!ci->empty() && !hasMetadata(ci, "enzyme_gradient")) { + if (!ci->empty() && !hasMetadata(ci, "enzyme_gradient") && + !hasMetadata(ci, "enzyme_derivative")) { visitIPOCall(call, *ci); } } diff --git a/enzyme/test/Integration/ForwardMode/customfwd.c b/enzyme/test/Integration/ForwardMode/customfwd.c new file mode 100644 index 0000000000000..839e6578fdf4a --- /dev/null +++ b/enzyme/test/Integration/ForwardMode/customfwd.c @@ -0,0 +1,49 @@ + +// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli - +// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli - +// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli - +// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli - +// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli - +// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli - +// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli - +// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli - + +#include "test_utils.h" + +double __enzyme_fwddiff(void*, ...); + +__attribute__((noinline)) +void square_(const double* src, double* dest) { + *dest = *src * *src; +} + +int derivative = 0; +void derivative_square_(const double* src, const double *d_src, const double* dest, double* d_dest) { + derivative++; + // intentionally incorrect for debugging + *d_dest = 100; +} + +void* __enzyme_register_derivative_square[] = { + (void*)square_, + (void*)derivative_square_, +}; + + +double square(double x) { + double y; + square_(&x, &y); + return y; +} + +double dsquare(double x) { + return __enzyme_fwddiff((void*)square, x, 1.0); +} + + +int main() { + double res = dsquare(3.0); + printf("res=%f derivative=%d\n", res, derivative); + APPROX_EQ(res, 100.0, 1e-10); + APPROX_EQ(derivative, 1.0, 1e-10); +} \ No newline at end of file