Skip to content

Commit

Permalink
EnzymeLogic cleanup (rust-lang#320)
Browse files Browse the repository at this point in the history
* rename CreateDual to CreateForwardDiff

* add clearFunctionAttributes
  • Loading branch information
tgymnich authored Sep 21, 2021
1 parent bea9d39 commit 788cc8e
Show file tree
Hide file tree
Showing 6 changed files with 374 additions and 82 deletions.
4 changes: 2 additions & 2 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -7581,11 +7581,11 @@ class AdjointGenerator
}
}

auto newcalled = gutils->Logic.CreatePrimalAndGradient(
auto newcalled = gutils->Logic.CreateForwardDiff(
cast<Function>(called), subretType, argsInverted, gutils->TLI,
TR.analyzer.interprocedural, /*returnValue*/ retUsed,
/*subdretptr*/ false, DerivativeMode::ForwardMode, nullptr,
nextTypeInfo, uncacheable_args, nullptr,
nextTypeInfo, uncacheable_args,
/*AtomicAdd*/ gutils->AtomicAdd);

assert(newcalled);
Expand Down
23 changes: 23 additions & 0 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,29 @@ LLVMBasicBlockRef EnzymeGradientUtilsAllocationBlock(GradientUtils *gutils) {
return wrap(gutils->inversionAllocs);
}

LLVMValueRef EnzymeCreateForwardDiff(
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed,
CDerivativeMode mode, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo,
uint8_t *_uncacheable_args, size_t uncacheable_args_size, uint8_t AtomicAdd,
uint8_t PostOpt) {
std::vector<DIFFE_TYPE> nconstant_args((DIFFE_TYPE *)constant_args,
(DIFFE_TYPE *)constant_args +
constant_args_size);
std::map<llvm::Argument *, bool> uncacheable_args;
size_t argnum = 0;
for (auto &arg : cast<Function>(unwrap(todiff))->args()) {
assert(argnum < uncacheable_args_size);
uncacheable_args[&arg] = _uncacheable_args[argnum];
argnum++;
}
return wrap(eunwrap(Logic).CreateForwardDiff(
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
eunwrap(TA).TLI, eunwrap(TA), returnValue, dretUsed, (DerivativeMode)mode,
unwrap(additionalArg), eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
uncacheable_args, AtomicAdd, PostOpt));
}
LLVMValueRef EnzymeCreatePrimalAndGradient(
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
Expand Down
8 changes: 8 additions & 0 deletions enzyme/Enzyme/CApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,14 @@ typedef enum {
DEM_ReverseModeCombined = 3,
} CDerivativeMode;

LLVMValueRef EnzymeCreateForwardDiff(
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed,
CDerivativeMode mode, LLVMTypeRef additionalArg,
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
size_t uncacheable_args_size, uint8_t AtomicAdd, uint8_t PostOpt);

LLVMValueRef EnzymeCreatePrimalAndGradient(
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
Expand Down
5 changes: 5 additions & 0 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,11 @@ class Enzyme : public ModulePass {
Type *tapeType = nullptr;
switch (mode) {
case DerivativeMode::ForwardMode:
newFunc = Logic.CreateForwardDiff(
cast<Function>(fn), retType, constants, TLI, TA,
/*should return*/ false, /*dretPtr*/ false, mode,
/*addedType*/ nullptr, type_args, volatile_args, AtomicAdd, PostOpt);
break;
case DerivativeMode::ReverseModeCombined:
newFunc = Logic.CreatePrimalAndGradient(
cast<Function>(fn), retType, constants, TLI, TA,
Expand Down
Loading

0 comments on commit 788cc8e

Please sign in to comment.