diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index dbfbad868b093..95a336d51a2f0 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -7581,11 +7581,11 @@ class AdjointGenerator } } - auto newcalled = gutils->Logic.CreatePrimalAndGradient( + auto newcalled = gutils->Logic.CreateForwardDiff( cast(called), subretType, argsInverted, gutils->TLI, TR.analyzer.interprocedural, /*returnValue*/ retUsed, /*subdretptr*/ false, DerivativeMode::ForwardMode, nullptr, - nextTypeInfo, uncacheable_args, nullptr, + nextTypeInfo, uncacheable_args, /*AtomicAdd*/ gutils->AtomicAdd); assert(newcalled); diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 21fe6aec886ac..52dc668187a3d 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -323,6 +323,29 @@ LLVMBasicBlockRef EnzymeGradientUtilsAllocationBlock(GradientUtils *gutils) { return wrap(gutils->inversionAllocs); } +LLVMValueRef EnzymeCreateForwardDiff( + EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType, + CDIFFE_TYPE *constant_args, size_t constant_args_size, + EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed, + CDerivativeMode mode, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo, + uint8_t *_uncacheable_args, size_t uncacheable_args_size, uint8_t AtomicAdd, + uint8_t PostOpt) { + std::vector nconstant_args((DIFFE_TYPE *)constant_args, + (DIFFE_TYPE *)constant_args + + constant_args_size); + std::map uncacheable_args; + size_t argnum = 0; + for (auto &arg : cast(unwrap(todiff))->args()) { + assert(argnum < uncacheable_args_size); + uncacheable_args[&arg] = _uncacheable_args[argnum]; + argnum++; + } + return wrap(eunwrap(Logic).CreateForwardDiff( + cast(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args, + eunwrap(TA).TLI, eunwrap(TA), returnValue, dretUsed, (DerivativeMode)mode, + unwrap(additionalArg), eunwrap(typeInfo, cast(unwrap(todiff))), + uncacheable_args, AtomicAdd, PostOpt)); +} LLVMValueRef 
EnzymeCreatePrimalAndGradient( EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, size_t constant_args_size, diff --git a/enzyme/Enzyme/CApi.h b/enzyme/Enzyme/CApi.h index bddf7bc36058c..4822ba78dd466 100644 --- a/enzyme/Enzyme/CApi.h +++ b/enzyme/Enzyme/CApi.h @@ -116,6 +116,14 @@ typedef enum { DEM_ReverseModeCombined = 3, } CDerivativeMode; +LLVMValueRef EnzymeCreateForwardDiff( + EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType, + CDIFFE_TYPE *constant_args, size_t constant_args_size, + EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed, + CDerivativeMode mode, LLVMTypeRef additionalArg, + struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args, + size_t uncacheable_args_size, uint8_t AtomicAdd, uint8_t PostOpt); + LLVMValueRef EnzymeCreatePrimalAndGradient( EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, size_t constant_args_size, diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index c2eea1b60d1ed..3f524a6cbe306 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -482,6 +482,11 @@ class Enzyme : public ModulePass { Type *tapeType = nullptr; switch (mode) { case DerivativeMode::ForwardMode: + newFunc = Logic.CreateForwardDiff( + cast(fn), retType, constants, TLI, TA, + /*should return*/ false, /*dretPtr*/ false, mode, + /*addedType*/ nullptr, type_args, volatile_args, AtomicAdd, PostOpt); + break; case DerivativeMode::ReverseModeCombined: newFunc = Logic.CreatePrimalAndGradient( cast(fn), retType, constants, TLI, TA, diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index ee304034016f0..71886ff17c1a0 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -1387,6 +1387,70 @@ bool legalCombinedForwardReverse( return true; } +void clearFunctionAttributes(Function *f) { + for (Argument &Arg : f->args()) { + if (Arg.hasAttribute(Attribute::Returned)) + 
Arg.removeAttr(Attribute::Returned); + if (Arg.hasAttribute(Attribute::StructRet)) + Arg.removeAttr(Attribute::StructRet); + } + if (f->hasFnAttribute(Attribute::OptimizeNone)) + f->removeFnAttr(Attribute::OptimizeNone); + + if (auto bytes = + f->getDereferenceableBytes(llvm::AttributeList::ReturnIndex)) { + AttrBuilder ab; + ab.addDereferenceableAttr(bytes); + f->removeAttributes(llvm::AttributeList::ReturnIndex, ab); + } + + if (f->getAttributes().getRetAlignment()) { + AttrBuilder ab; + ab.addAlignmentAttr(f->getAttributes().getRetAlignment()); + f->removeAttributes(llvm::AttributeList::ReturnIndex, ab); + } + if (f->hasAttribute(llvm::AttributeList::ReturnIndex, + llvm::Attribute::NoAlias)) { + f->removeAttribute(llvm::AttributeList::ReturnIndex, + llvm::Attribute::NoAlias); + } +#if LLVM_VERSION_MAJOR >= 11 + if (f->hasAttribute(llvm::AttributeList::ReturnIndex, + llvm::Attribute::NoUndef)) { + f->removeAttribute(llvm::AttributeList::ReturnIndex, + llvm::Attribute::NoUndef); + } +#endif + if (f->hasAttribute(llvm::AttributeList::ReturnIndex, + llvm::Attribute::NonNull)) { + f->removeAttribute(llvm::AttributeList::ReturnIndex, + llvm::Attribute::NonNull); + } + if (f->hasAttribute(llvm::AttributeList::ReturnIndex, + llvm::Attribute::ZExt)) { + f->removeAttribute(llvm::AttributeList::ReturnIndex, llvm::Attribute::ZExt); + } +} + +void cleanupInversionAllocs(DiffeGradientUtils *gutils, BasicBlock *entry) { + while (gutils->inversionAllocs->size() > 0) { + Instruction *inst = &gutils->inversionAllocs->back(); + if (isa(inst)) + inst->moveBefore(&gutils->newFunc->getEntryBlock().front()); + else + inst->moveBefore(entry->getFirstNonPHIOrDbgOrLifetime()); + } + + (IRBuilder<>(gutils->inversionAllocs)).CreateUnreachable(); + DeleteDeadBlock(gutils->inversionAllocs); + for (auto BBs : gutils->reverseBlocks) { + if (pred_begin(BBs.second.front()) == pred_end(BBs.second.front())) { + (IRBuilder<>(BBs.second.front())).CreateUnreachable(); + 
DeleteDeadBlock(BBs.second.front()); + } + } +} + //! return structtype if recursive function const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( Function *todiff, DIFFE_TYPE retType, @@ -2068,38 +2132,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( } } - for (Argument &Arg : NewF->args()) { - if (Arg.hasAttribute(Attribute::Returned)) - Arg.removeAttr(Attribute::Returned); - if (Arg.hasAttribute(Attribute::StructRet)) - Arg.removeAttr(Attribute::StructRet); - } - if (NewF->hasFnAttribute(Attribute::OptimizeNone)) - NewF->removeFnAttr(Attribute::OptimizeNone); - - if (auto bytes = - NewF->getDereferenceableBytes(llvm::AttributeList::ReturnIndex)) { - AttrBuilder ab; - ab.addDereferenceableAttr(bytes); - NewF->removeAttributes(llvm::AttributeList::ReturnIndex, ab); - } - if (NewF->hasAttribute(llvm::AttributeList::ReturnIndex, - llvm::Attribute::NoAlias)) { - NewF->removeAttribute(llvm::AttributeList::ReturnIndex, - llvm::Attribute::NoAlias); - } -#if LLVM_VERSION_MAJOR >= 11 - if (NewF->hasAttribute(llvm::AttributeList::ReturnIndex, - llvm::Attribute::NoUndef)) { - NewF->removeAttribute(llvm::AttributeList::ReturnIndex, - llvm::Attribute::NoUndef); - } -#endif - if (NewF->hasAttribute(llvm::AttributeList::ReturnIndex, - llvm::Attribute::ZExt)) { - NewF->removeAttribute(llvm::AttributeList::ReturnIndex, - llvm::Attribute::ZExt); - } + clearFunctionAttributes(NewF); if (llvm::verifyFunction(*NewF, &llvm::errs())) { llvm::errs() << *gutils->oldFunc << "\n"; @@ -2603,8 +2636,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient( bool omp) { assert(mode == DerivativeMode::ReverseModeCombined || - mode == DerivativeMode::ReverseModeGradient || - mode == DerivativeMode::ForwardMode); + mode == DerivativeMode::ReverseModeGradient); FnTypeInfo oldTypeInfo = oldTypeInfo_; for (auto &pair : oldTypeInfo.KnownValues) { @@ -3311,67 +3343,282 @@ Function *EnzymeLogic::CreatePrimalAndGradient( } } - while (gutils->inversionAllocs->size() > 0) { - Instruction 
*inst = &gutils->inversionAllocs->back(); - if (isa(inst)) - inst->moveBefore(&gutils->newFunc->getEntryBlock().front()); - else - inst->moveBefore(entry->getFirstNonPHIOrDbgOrLifetime()); + cleanupInversionAllocs(gutils, entry); + clearFunctionAttributes(gutils->newFunc); + + if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) { + llvm::errs() << *gutils->oldFunc << "\n"; + llvm::errs() << *gutils->newFunc << "\n"; + report_fatal_error("function failed verification (4)"); } - (IRBuilder<>(gutils->inversionAllocs)).CreateUnreachable(); - DeleteDeadBlock(gutils->inversionAllocs); - for (auto BBs : gutils->reverseBlocks) { - if (pred_begin(BBs.second.front()) == pred_end(BBs.second.front())) { - (IRBuilder<>(BBs.second.front())).CreateUnreachable(); - DeleteDeadBlock(BBs.second.front()); + auto nf = gutils->newFunc; + delete gutils; + + { + PreservedAnalyses PA; + PPC.FAM.invalidate(*nf, PA); + } + PPC.AlwaysInline(nf); + if (Arch == Triple::nvptx || Arch == Triple::nvptx64) + PPC.ReplaceReallocs(nf, /*mem2reg*/ true); + + if (PostOpt) + PPC.optimizeIntermediate(nf); + if (EnzymePrint) { + llvm::errs() << *nf << "\n"; + } + return nf; +} + +Function *EnzymeLogic::CreateForwardDiff( + Function *todiff, DIFFE_TYPE retType, + const std::vector &constant_args, TargetLibraryInfo &TLI, + TypeAnalysis &TA, bool returnUsed, bool dretPtr, DerivativeMode mode, + llvm::Type *additionalArg, const FnTypeInfo &oldTypeInfo_, + const std::map _uncacheable_args, bool AtomicAdd, + bool PostOpt, bool omp) { + + assert(mode == DerivativeMode::ForwardMode); + + FnTypeInfo oldTypeInfo = oldTypeInfo_; + for (auto &pair : oldTypeInfo.KnownValues) { + if (pair.second.size() != 0) { + bool recursiveUse = false; + for (auto user : pair.first->users()) { + if (auto bi = dyn_cast(user)) { + for (auto biuser : bi->users()) { + if (auto ci = dyn_cast(biuser)) { + if (ci->getCalledFunction() == todiff && + ci->getArgOperand(pair.first->getArgNo()) == bi) { + recursiveUse = true; + break; + } 
+ } + } + } + if (recursiveUse) + break; + } + if (recursiveUse) { + pair.second.clear(); + } } } - for (Argument &Arg : gutils->newFunc->args()) { - if (Arg.hasAttribute(Attribute::Returned)) - Arg.removeAttr(Attribute::Returned); - if (Arg.hasAttribute(Attribute::StructRet)) - Arg.removeAttr(Attribute::StructRet); + if (retType != DIFFE_TYPE::CONSTANT) + assert(!todiff->getReturnType()->isVoidTy()); + + ReverseCacheKey tup = + std::make_tuple(todiff, retType, constant_args, + std::map(_uncacheable_args.begin(), + _uncacheable_args.end()), + returnUsed, dretPtr, mode, additionalArg, oldTypeInfo); + if (ReverseCachedFunctions.find(tup) != ReverseCachedFunctions.end()) { + return ReverseCachedFunctions.find(tup)->second; } - if (gutils->newFunc->hasFnAttribute(Attribute::OptimizeNone)) - gutils->newFunc->removeFnAttr(Attribute::OptimizeNone); - if (auto bytes = gutils->newFunc->getDereferenceableBytes( - llvm::AttributeList::ReturnIndex)) { - AttrBuilder ab; - ab.addDereferenceableAttr(bytes); - gutils->newFunc->removeAttributes(llvm::AttributeList::ReturnIndex, ab); + // Whether we should actually return the value + bool returnValue = returnUsed; + + // TODO change this to go by default function type assumptions + bool hasconstant = false; + for (auto v : constant_args) { + if (v == DIFFE_TYPE::CONSTANT) { + hasconstant = true; + break; + } } - if (gutils->newFunc->getAttributes().getRetAlignment()) { - AttrBuilder ab; - ab.addAlignmentAttr(gutils->newFunc->getAttributes().getRetAlignment()); - gutils->newFunc->removeAttributes(llvm::AttributeList::ReturnIndex, ab); + assert(!todiff->empty()); + + auto TRo = TA.analyzeFunction(oldTypeInfo); + bool retActive = TRo.getReturnAnalysis().Inner0().isPossibleFloat() && + !todiff->getReturnType()->isVoidTy(); + + ReturnType retVal = + returnValue ? (retActive ? ReturnType::TwoReturns : ReturnType::Return) + : (retActive ? 
ReturnType::Return : ReturnType::Void); + + bool diffeReturnArg = false; + + DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone( + *this, mode, todiff, TLI, TA, retType, diffeReturnArg, constant_args, + retVal, additionalArg, omp); + + gutils->AtomicAdd = AtomicAdd; + insert_or_assign2(ReverseCachedFunctions, tup, + gutils->newFunc); + + const SmallPtrSet guaranteedUnreachable = + getGuaranteedUnreachable(gutils->oldFunc); + + // Convert uncacheable args from the input function to the preprocessed + // function + std::map _uncacheable_argsPP; + { + auto in_arg = todiff->arg_begin(); + auto pp_arg = gutils->oldFunc->arg_begin(); + for (; pp_arg != gutils->oldFunc->arg_end();) { + _uncacheable_argsPP[pp_arg] = _uncacheable_args.find(in_arg)->second; + ++pp_arg; + ++in_arg; + } } - if (gutils->newFunc->hasAttribute(llvm::AttributeList::ReturnIndex, - llvm::Attribute::NoAlias)) { - gutils->newFunc->removeAttribute(llvm::AttributeList::ReturnIndex, - llvm::Attribute::NoAlias); + + FnTypeInfo typeInfo(gutils->oldFunc); + { + auto toarg = todiff->arg_begin(); + auto olarg = gutils->oldFunc->arg_begin(); + for (; toarg != todiff->arg_end(); ++toarg, ++olarg) { + + { + auto fd = oldTypeInfo.Arguments.find(toarg); + assert(fd != oldTypeInfo.Arguments.end()); + typeInfo.Arguments.insert( + std::pair(olarg, fd->second)); + } + + { + auto cfd = oldTypeInfo.KnownValues.find(toarg); + assert(cfd != oldTypeInfo.KnownValues.end()); + typeInfo.KnownValues.insert( + std::pair>(olarg, cfd->second)); + } + } + typeInfo.Return = oldTypeInfo.Return; } -#if LLVM_VERSION_MAJOR >= 11 - if (gutils->newFunc->hasAttribute(llvm::AttributeList::ReturnIndex, - llvm::Attribute::NoUndef)) { - gutils->newFunc->removeAttribute(llvm::AttributeList::ReturnIndex, - llvm::Attribute::NoUndef); + + TypeResults TR = TA.analyzeFunction(typeInfo); + assert(TR.getFunction() == gutils->oldFunc); + + gutils->forceActiveDetection(TR); + gutils->forceAugmentedReturns(TR, guaranteedUnreachable); + + 
gutils->computeGuaranteedFrees(guaranteedUnreachable); + + // TODO populate with actual unnecessaryInstructions once the dependency + // cycle with activity analysis is removed + SmallPtrSet unnecessaryInstructionsTmp; + for (auto BB : guaranteedUnreachable) { + for (auto &I : *BB) + unnecessaryInstructionsTmp.insert(&I); } -#endif - if (gutils->newFunc->hasAttribute(llvm::AttributeList::ReturnIndex, - llvm::Attribute::NonNull)) { - gutils->newFunc->removeAttribute(llvm::AttributeList::ReturnIndex, - llvm::Attribute::NonNull); + CacheAnalysis CA(gutils->allocationsWithGuaranteedFree, TR, gutils->OrigAA, + gutils->oldFunc, + PPC.FAM.getResult(*gutils->oldFunc), + gutils->OrigLI, gutils->OrigDT, TLI, + unnecessaryInstructionsTmp, _uncacheable_argsPP, mode, omp); + const std::map> + uncacheable_args_map = CA.compute_uncacheable_args_for_callsites(); + + const std::map can_modref_map = + CA.compute_uncacheable_load_map(); + gutils->can_modref_map = &can_modref_map; + + std::map, int> mapping; + + auto getIndex = [&](Instruction *I, CacheType u) -> unsigned { + return gutils->getIndex(std::make_pair(I, u), mapping); + }; + + gutils->computeMinCache(TR, guaranteedUnreachable); + + SmallPtrSet unnecessaryValues; + SmallPtrSet unnecessaryInstructions; + calculateUnusedValuesInFunction( + *gutils->oldFunc, unnecessaryValues, unnecessaryInstructions, returnValue, + mode, TR, gutils, TLI, constant_args, guaranteedUnreachable); + + SmallPtrSet unnecessaryStores; + calculateUnusedStoresInFunction(*gutils->oldFunc, unnecessaryStores, + unnecessaryInstructions, gutils); + + // set derivative of function arguments + auto newArgs = gutils->newFunc->arg_begin(); + + for (size_t i = 0; i < constant_args.size(); ++i) { + auto arg = constant_args[i]; + if (arg == DIFFE_TYPE::DUP_ARG) { + newArgs += 1; + auto pri = gutils->oldFunc->arg_begin() + i; + auto dif = newArgs; + + BasicBlock &BB = gutils->newFunc->getEntryBlock(); + IRBuilder<> Builder(&BB.front()); + + gutils->setDiffe(pri, 
dif, Builder); + } + newArgs += 1; } - if (gutils->newFunc->hasAttribute(llvm::AttributeList::ReturnIndex, - llvm::Attribute::ZExt)) { - gutils->newFunc->removeAttribute(llvm::AttributeList::ReturnIndex, - llvm::Attribute::ZExt); + + AdjointGenerator maker( + mode, gutils, constant_args, retType, TR, getIndex, uncacheable_args_map, + /*returnuses*/ nullptr, nullptr, nullptr, unnecessaryValues, + unnecessaryInstructions, unnecessaryStores, guaranteedUnreachable, + nullptr); + + for (BasicBlock &oBB : *gutils->oldFunc) { + // Don't create derivatives for code that results in termination + if (guaranteedUnreachable.find(&oBB) != guaranteedUnreachable.end()) { + auto newBB = cast(gutils->getNewFromOriginal(&oBB)); + std::vector toRemove; + if (auto II = dyn_cast(oBB.getTerminator())) { + toRemove.push_back( + cast(gutils->getNewFromOriginal(II->getNormalDest()))); + } else { + for (auto next : successors(&oBB)) { + auto sucBB = cast(gutils->getNewFromOriginal(next)); + toRemove.push_back(sucBB); + } + } + + for (auto sucBB : toRemove) { + sucBB->removePredecessor(newBB); + } + + std::vector toerase; + for (auto &I : oBB) { + toerase.push_back(&I); + } + for (auto I : toerase) { + maker.eraseIfUnused(*I, /*erase*/ true, /*check*/ true); + } + if (newBB->getTerminator()) + newBB->getTerminator()->eraseFromParent(); + IRBuilder<> builder(newBB); + builder.CreateUnreachable(); + continue; + } + + auto term = oBB.getTerminator(); + assert(term); + if (!isa(term) && !isa(term) && + !isa(term)) { + llvm::errs() << *oBB.getParent() << "\n"; + llvm::errs() << "unknown terminator instance " << *term << "\n"; + assert(0 && "unknown terminator inst"); + } + + auto first = oBB.begin(); + auto last = oBB.empty() ? 
oBB.end() : std::prev(oBB.end()); + for (auto it = first; it != last; ++it) { + maker.visit(&*it); + } + + createTerminator(gutils, &oBB, retType, retVal); } + gutils->eraseFictiousPHIs(); + + BasicBlock *entry = &gutils->newFunc->getEntryBlock(); + + auto Arch = + llvm::Triple(gutils->newFunc->getParent()->getTargetTriple()).getArch(); + + cleanupInversionAllocs(gutils, entry); + clearFunctionAttributes(gutils->newFunc); + if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) { llvm::errs() << *gutils->oldFunc << "\n"; llvm::errs() << *gutils->newFunc << "\n"; diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 151774dcec523..564b5cf2e7a4e 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -189,6 +189,15 @@ class EnzymeLogic { const AugmentedReturn *augmented, bool AtomicAdd, bool PostOpt = false, bool omp = false); + llvm::Function * + CreateForwardDiff(llvm::Function *todiff, DIFFE_TYPE retType, + const std::vector &constant_args, + llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA, + bool returnValue, bool dretUsed, DerivativeMode mode, + llvm::Type *additionalArg, const FnTypeInfo &typeInfo, + const std::map _uncacheable_args, + bool AtomicAdd, bool PostOpt = false, bool omp = false); + void clear(); };