Skip to content

Commit

Permalink
Keep integer extract (rust-lang#914)
Browse files Browse the repository at this point in the history
* Keep integer extract

* Fixup
  • Loading branch information
wsmoses authored Oct 24, 2022
1 parent 2a8ba2d commit 235fc1d
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 52 deletions.
84 changes: 49 additions & 35 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,52 @@ class AdjointGenerator
bool constantval = gutils->isConstantValue(orig_val) ||
parseTBAA(I, DL).Inner0().isIntegral();

// TODO allow recognition of other types that could contain pointers [e.g.
// {void*, void*} or <2 x i64> ]
auto storeSize = DL.getTypeSizeInBits(valType) / 8;

auto vd = TR.query(orig_ptr).Lookup(storeSize, DL);

if (!vd.isKnown()) {
if (looseTypeAnalysis || true) {
vd = defaultTypeTreeForLLVM(valType, &I);
EmitWarning("CannotDeduceType", I, "failed to deduce type of xtore ",
I);
goto known;
}
if (CustomErrorHandler) {
std::string str;
raw_string_ostream ss(str);
ss << "Cannot deduce type of store " << I;
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
&TR.analyzer);
}
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I,
"failed to deduce type of store ", I);

TR.intType(storeSize, orig_ptr, /*errifnotfound*/ true,
/*pointerIntSame*/ true);
llvm_unreachable("bad mti");
known:;
}

auto dt = vd[{-1}];
for (size_t i = 0; i < storeSize; ++i) {
bool Legal = true;
dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal);
if (!Legal) {
if (CustomErrorHandler) {
std::string str;
raw_string_ostream ss(str);
ss << "Cannot deduce single type of store " << I;
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
&TR.analyzer);
}
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I,
"failed to deduce single type of store ", I);
}
}

if (Mode == DerivativeMode::ForwardMode) {
IRBuilder<> Builder2(&I);
getForwardBuilder(Builder2);
Expand All @@ -1140,7 +1186,8 @@ class AdjointGenerator
// TODO type analyze
if (!constantval)
diff = gutils->invertPointerM(orig_val, Builder2, /*nullShadow*/ true);
else if (orig_val->getType()->isPointerTy())
else if (orig_val->getType()->isPointerTy() || dt == BaseType::Pointer ||
dt == BaseType::Integer)
diff = gutils->invertPointerM(orig_val, Builder2, /*nullShadow*/ false);
else
diff = gutils->invertPointerM(orig_val, Builder2, /*nullShadow*/ true);
Expand All @@ -1150,41 +1197,8 @@ class AdjointGenerator
return;
}

// TODO allow recognition of other types that could contain pointers [e.g.
// {void*, void*} or <2 x i64> ]
auto storeSize = DL.getTypeSizeInBits(valType) / 8;

//! Storing a floating point value
Type *FT = nullptr;
if (valType->isFPOrFPVectorTy()) {
FT = valType->getScalarType();
} else if (!valType->isPointerTy()) {
auto fp =
TR.firstPointer(storeSize, orig_ptr, &I, /*errifnotfound*/ false,
/*pointerIntSame*/ true);
if (fp.isKnown()) {
FT = fp.isFloat();
} else if (looseTypeAnalysis && (isa<ConstantInt>(orig_val) ||
valType->isIntOrIntVectorTy())) {
llvm::errs() << "assuming type as integral for store: " << I << "\n";
FT = nullptr;
} else {

if (CustomErrorHandler) {
std::string str;
raw_string_ostream ss(str);
ss << "Cannot deduce type of store " << I;
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
&TR.analyzer);
}
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I,
"failed to deduce type of store ", I);
TR.firstPointer(storeSize, orig_ptr, &I, /*errifnotfound*/ true,
/*pointerIntSame*/ true);
}
}

if (FT) {
if (Type *FT = dt.isFloat()) {
//! Only need to update the reverse function
switch (Mode) {
case DerivativeMode::ReverseModePrimal:
Expand Down
20 changes: 14 additions & 6 deletions enzyme/Enzyme/DifferentialUseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -722,8 +722,10 @@ static inline bool is_value_needed_in_reverse(
}

if (isa<ReturnInst>(user)) {
if (gutils->ATA->ActiveReturns == DIFFE_TYPE::DUP_ARG ||
gutils->ATA->ActiveReturns == DIFFE_TYPE::DUP_NONEED) {
if ((gutils->ATA->ActiveReturns == DIFFE_TYPE::DUP_ARG ||
gutils->ATA->ActiveReturns == DIFFE_TYPE::DUP_NONEED) &&
((inst_cv && VT == ValueType::Primal) ||
(!inst_cv && VT == ValueType::Shadow))) {
if (EnzymePrintDiffUse)
llvm::errs() << " Need: " << to_string(VT) << " of " << *inst
<< " in reverse as shadow return " << *user << "\n";
Expand Down Expand Up @@ -755,10 +757,10 @@ static inline bool is_value_needed_in_reverse(
if (user->getType()->isVoidTy())
goto endShadow;

if (!TR.query(const_cast<Instruction *>(user))
.Inner0()
.isPossiblePointer())
if (!TR.query(const_cast<Instruction *>(user))[{-1}]
.isPossiblePointer()) {
goto endShadow;
}

if (!OneLevel && is_value_needed_in_reverse<ValueType::Shadow>(
gutils, user, mode, seen, oldUnreachable)) {
Expand Down Expand Up @@ -884,14 +886,20 @@ static inline bool is_value_needed_in_reverse(
bool valueIsIndex = false;
for (unsigned i = 2; i < IVI->getNumOperands(); ++i) {
if (IVI->getOperand(i) == inst) {
if (inst == IVI->getInsertedValueOperand() &&
TR.query(
const_cast<Value *>(IVI->getInsertedValueOperand()))[{-1}]
.isFloat()) {
continue;
}
valueIsIndex = true;
}
}
primalUsedInShadowPointer = valueIsIndex;
}
if (auto EVI = dyn_cast<ExtractValueInst>(user)) {
bool valueIsIndex = false;
for (unsigned i = 2; i < EVI->getNumOperands(); ++i) {
for (unsigned i = 1; i < EVI->getNumOperands(); ++i) {
if (EVI->getOperand(i) == inst) {
valueIsIndex = true;
}
Expand Down
11 changes: 3 additions & 8 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2828,11 +2828,7 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB,
} else if (!gutils->isConstantValue(ret)) {
toret = gutils->diffe(ret, nBuilder);
} else {
IRBuilder<> eB(gutils->inversionAllocs);
Type *retTy = gutils->getShadowType(ret->getType());
auto al = eB.CreateAlloca(retTy);
ZeroMemory(eB, retTy, al, /*isTape*/ false);
toret = nBuilder.CreateLoad(al);
toret = gutils->invertPointerM(ret, nBuilder, /*nullInit*/ true);
}

break;
Expand All @@ -2853,9 +2849,8 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB,
toret =
nBuilder.CreateInsertValue(toret, gutils->diffe(ret, nBuilder), 1);
} else {
Type *retTy = gutils->getShadowType(ret->getType());
toret =
nBuilder.CreateInsertValue(toret, Constant::getNullValue(retTy), 1);
toret = nBuilder.CreateInsertValue(
toret, gutils->invertPointerM(ret, nBuilder, /*nullInit*/ true), 1);
}
break;
}
Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4060,7 +4060,8 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
return applyChainRule(oval->getType(), BuilderM, rule);
}

if (isConstantValue(oval)) {
if (isConstantValue(oval) && !isa<InsertValueInst>(oval) &&
!isa<ExtractValueInst>(oval)) {
// NOTE, this is legal and the correct resolution, however, our activity
// analysis honeypot no longer exists

Expand All @@ -4084,7 +4085,6 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,

return applyChainRule(oval->getType(), BuilderM, rule);
}
assert(!isConstantValue(oval));

auto M = oldFunc->getParent();
assert(oval);
Expand Down Expand Up @@ -4477,7 +4477,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
goto end;
} else if (auto arg = dyn_cast<ExtractValueInst>(oval)) {
IRBuilder<> bb(getNewFromOriginal(arg));
auto ip = invertPointerM(arg->getOperand(0), bb);
auto ip = invertPointerM(arg->getOperand(0), bb, nullShadow);

auto rule = [&bb, &arg](Value *ip) {
return bb.CreateExtractValue(ip, arg->getIndices(),
Expand Down

0 comments on commit 235fc1d

Please sign in to comment.